File size: 10,050 Bytes
1da5d75
1
{"cells": [
  {"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": [
    "# Gradio Demo: mini_leaderboard\n",
    "\n",
    "A trimmed-down LLM leaderboard: search for models, choose which columns to show, and filter by model type, precision and size."
  ]},
  {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": [
    "%pip install -q gradio"
  ]},
  {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": [
    "# Downloading files from the demo repo\n",
    "import os\n",
    "os.makedirs('assets', exist_ok=True)\n",
    "!wget -q -O assets/__init__.py https://github.com/gradio-app/gradio/raw/main/demo/mini_leaderboard/assets/__init__.py\n",
    "!wget -q -O assets/custom_css.css https://github.com/gradio-app/gradio/raw/main/demo/mini_leaderboard/assets/custom_css.css\n",
    "!wget -q -O assets/leaderboard_data.json https://github.com/gradio-app/gradio/raw/main/demo/mini_leaderboard/assets/leaderboard_data.json"
  ]},
  {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": [
    "import gradio as gr\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "\n",
    "# `__file__` is set when this code runs as a script, but not inside a notebook\n",
    "# kernel; there, fall back to the working directory, which is where the\n",
    "# download cell above saved `assets/`.\n",
    "try:\n",
    "    abs_path = Path(__file__).parent.absolute()\n",
    "except NameError:\n",
    "    abs_path = Path.cwd()\n",
    "\n",
    "df = pd.read_json(str(abs_path / \"assets/leaderboard_data.json\"))\n",
    "# Untouched copy that feeds the hidden, full-width table used for searching.\n",
    "invisible_df = df.copy()\n",
    "\n",
    "\n",
    "# Every selectable column, in display order (select_columns keeps this order).\n",
    "COLS = [\n",
    "    \"T\",\n",
    "    \"Model\",\n",
    "    \"Average \u2b06\ufe0f\",\n",
    "    \"ARC\",\n",
    "    \"HellaSwag\",\n",
    "    \"MMLU\",\n",
    "    \"TruthfulQA\",\n",
    "    \"Winogrande\",\n",
    "    \"GSM8K\",\n",
    "    \"Type\",\n",
    "    \"Architecture\",\n",
    "    \"Precision\",\n",
    "    \"Merged\",\n",
    "    \"Hub License\",\n",
    "    \"#Params (B)\",\n",
    "    \"Hub \u2764\ufe0f\",\n",
    "    \"Model sha\",\n",
    "    \"model_name_for_query\",\n",
    "]\n",
    "# Columns visible when the page first loads.\n",
    "ON_LOAD_COLS = [\n",
    "    \"T\",\n",
    "    \"Model\",\n",
    "    \"Average \u2b06\ufe0f\",\n",
    "    \"ARC\",\n",
    "    \"HellaSwag\",\n",
    "    \"MMLU\",\n",
    "    \"TruthfulQA\",\n",
    "    \"Winogrande\",\n",
    "    \"GSM8K\",\n",
    "    \"model_name_for_query\",\n",
    "]\n",
    "# Gradio datatypes, matched to table columns by position.\n",
    "# NOTE(review): TYPES has 22 entries but COLS has 18, and from \"Merged\" onward\n",
    "# the two lists appear misaligned (e.g. \"Hub License\" -> \"bool\",\n",
    "# \"Model sha\" -> \"number\"). The first 10 entries do match ON_LOAD_COLS.\n",
    "# Confirm the intended datatype of each column.\n",
    "TYPES = [\n",
    "    \"str\",\n",
    "    \"markdown\",\n",
    "    \"number\",\n",
    "    \"number\",\n",
    "    \"number\",\n",
    "    \"number\",\n",
    "    \"number\",\n",
    "    \"number\",\n",
    "    \"number\",\n",
    "    \"str\",\n",
    "    \"str\",\n",
    "    \"str\",\n",
    "    \"str\",\n",
    "    \"bool\",\n",
    "    \"str\",\n",
    "    \"number\",\n",
    "    \"number\",\n",
    "    \"bool\",\n",
    "    \"str\",\n",
    "    \"bool\",\n",
    "    \"bool\",\n",
    "    \"str\",\n",
    "]\n",
    "# Size checkbox label -> parameter-count range (billions). The intervals are\n",
    "# right-closed and do not overlap, so a model falls into at most one bucket.\n",
    "NUMERIC_INTERVALS = {\n",
    "    \"?\": pd.Interval(-1, 0, closed=\"right\"),\n",
    "    \"~1.5\": pd.Interval(0, 2, closed=\"right\"),\n",
    "    \"~3\": pd.Interval(2, 4, closed=\"right\"),\n",
    "    \"~7\": pd.Interval(4, 9, closed=\"right\"),\n",
    "    \"~13\": pd.Interval(9, 20, closed=\"right\"),\n",
    "    \"~35\": pd.Interval(20, 45, closed=\"right\"),\n",
    "    \"~60\": pd.Interval(45, 70, closed=\"right\"),\n",
    "    \"70+\": pd.Interval(70, 10000, closed=\"right\"),\n",
    "}\n",
    "# Filter choices are the distinct values present in the data, as strings.\n",
    "MODEL_TYPE = [str(s) for s in df[\"T\"].unique()]\n",
    "PRECISIONS = [str(s) for s in df[\"Precision\"].unique()]\n",
    "\n",
    "\n",
    "# Searching and filtering\n",
    "def update_table(\n",
    "    hidden_df: pd.DataFrame,\n",
    "    columns: list,\n",
    "    type_query: list,\n",
    "    precision_query: list,\n",
    "    size_query: list,\n",
    "    query: str,\n",
    "):\n",
    "    \"\"\"Apply the filters and the search query, then keep the chosen columns.\n",
    "\n",
    "    Args:\n",
    "        hidden_df: full leaderboard (all COLS) from the hidden table.\n",
    "        columns: column names ticked in \"Select columns to show\".\n",
    "        type_query: ticked model types (entries of MODEL_TYPE).\n",
    "        precision_query: ticked precisions (entries of PRECISIONS).\n",
    "        size_query: ticked keys of NUMERIC_INTERVALS.\n",
    "        query: search-bar text; `;` separates alternative terms.\n",
    "\n",
    "    Returns:\n",
    "        The frame to show in the visible leaderboard table.\n",
    "    \"\"\"\n",
    "    filtered_df = filter_models(hidden_df, type_query, size_query, precision_query)\n",
    "    filtered_df = filter_queries(query, filtered_df)\n",
    "    return select_columns(filtered_df, columns)\n",
    "\n",
    "\n",
    "def search_table(df: pd.DataFrame, query: str) -> pd.DataFrame:\n",
    "    \"\"\"Return the rows whose `model_name_for_query` contains `query`.\n",
    "\n",
    "    Matching is a case-insensitive plain substring test. The query comes\n",
    "    straight from the search bar, so it is not treated as a regex: input such\n",
    "    as `(` or `[` would otherwise raise `re.error`. Missing names never match.\n",
    "    \"\"\"\n",
    "    matches = df[\"model_name_for_query\"].str.contains(\n",
    "        query, case=False, regex=False, na=False\n",
    "    )\n",
    "    return df[matches]\n",
    "\n",
    "\n",
    "def select_columns(df: pd.DataFrame, columns: list) -> pd.DataFrame:\n",
    "    \"\"\"Keep the requested columns that exist in `df`, ordered as in COLS.\"\"\"\n",
    "    # Iterate over COLS rather than `columns`, so the output order does not\n",
    "    # depend on the order in which the checkboxes were ticked.\n",
    "    return df[[c for c in COLS if c in df.columns and c in columns]]\n",
    "\n",
    "\n",
    "def filter_queries(query: str, filtered_df: pd.DataFrame) -> pd.DataFrame:\n",
    "    \"\"\"Keep the rows matching any `;`-separated term of `query`.\n",
    "\n",
    "    Blank terms are ignored; if no term remains, `filtered_df` is returned\n",
    "    unchanged. A row matched by several terms is kept once (rows are\n",
    "    de-duplicated on Model, Precision and Model sha). If there are terms but\n",
    "    none of them matches, an empty frame with the same columns is returned.\n",
    "    \"\"\"\n",
    "    terms = [term.strip() for term in (query or \"\").split(\";\") if term.strip()]\n",
    "    if not terms:\n",
    "        return filtered_df\n",
    "\n",
    "    matches = [search_table(filtered_df, term) for term in terms]\n",
    "    # Skip empty results so pd.concat only sees frames that contribute rows.\n",
    "    matches = [match for match in matches if len(match) > 0]\n",
    "    if not matches:\n",
    "        # Show no rows rather than silently falling back to the full table.\n",
    "        return filtered_df.iloc[0:0]\n",
    "    return pd.concat(matches).drop_duplicates(\n",
    "        subset=[\"Model\", \"Precision\", \"Model sha\"]\n",
    "    )\n",
    "\n",
    "\n",
    "def filter_models(\n",
    "    df: pd.DataFrame,\n",
    "    type_query: list,\n",
    "    size_query: list,\n",
    "    precision_query: list,\n",
    ") -> pd.DataFrame:\n",
    "    \"\"\"Keep the rows whose type, precision and size are all selected.\n",
    "\n",
    "    Args:\n",
    "        df: leaderboard frame with \"T\", \"Precision\" and \"#Params (B)\" columns.\n",
    "        type_query: selected model types (entries of MODEL_TYPE).\n",
    "        size_query: selected keys of NUMERIC_INTERVALS.\n",
    "        precision_query: selected precisions (entries of PRECISIONS). Rows whose\n",
    "            precision is `None` are always kept.\n",
    "\n",
    "    Returns:\n",
    "        The matching rows of `df` (possibly empty).\n",
    "    \"\"\"\n",
    "    # The checkbox choices were built with str(), so compare on the str() form\n",
    "    # of each cell; this also lets the \"None\" entry match missing precisions.\n",
    "    type_mask = df[\"T\"].astype(str).isin(type_query)\n",
    "    precision_mask = df[\"Precision\"].astype(str).isin(\n",
    "        list(precision_query) + [\"None\"]\n",
    "    )\n",
    "\n",
    "    intervals = [NUMERIC_INTERVALS[s] for s in size_query]\n",
    "    # Unparseable parameter counts become NaN, which lies in no interval.\n",
    "    params = pd.to_numeric(df[\"#Params (B)\"], errors=\"coerce\")\n",
    "    size_mask = params.apply(lambda x: any(x in interval for interval in intervals))\n",
    "\n",
    "    return df.loc[type_mask & precision_mask & size_mask]\n",
    "\n",
    "\n",
    "demo = gr.Blocks(css=str(abs_path / \"assets/custom_css.css\"))\n",
    "with demo:\n",
    "    gr.Markdown(\"\"\"Test Space of the LLM Leaderboard\"\"\", elem_classes=\"markdown-text\")\n",
    "\n",
    "    with gr.Tabs(elem_classes=\"tab-buttons\") as tabs:\n",
    "        with gr.TabItem(\"\ud83c\udfc5 LLM Benchmark\", elem_id=\"llm-benchmark-tab-table\", id=0):\n",
    "            with gr.Row():\n",
    "                with gr.Column():\n",
    "                    with gr.Row():\n",
    "                        search_bar = gr.Textbox(\n",
    "                            placeholder=\" \ud83d\udd0d Search for your model (separate multiple queries with `;`) and press ENTER...\",\n",
    "                            show_label=False,\n",
    "                            elem_id=\"search-bar\",\n",
    "                        )\n",
    "                    with gr.Row():\n",
    "                        shown_columns = gr.CheckboxGroup(\n",
    "                            choices=COLS,\n",
    "                            value=ON_LOAD_COLS,\n",
    "                            label=\"Select columns to show\",\n",
    "                            elem_id=\"column-select\",\n",
    "                            interactive=True,\n",
    "                        )\n",
    "                with gr.Column(min_width=320):\n",
    "                    filter_columns_type = gr.CheckboxGroup(\n",
    "                        label=\"Model types\",\n",
    "                        choices=MODEL_TYPE,\n",
    "                        value=MODEL_TYPE,\n",
    "                        interactive=True,\n",
    "                        elem_id=\"filter-columns-type\",\n",
    "                    )\n",
    "                    filter_columns_precision = gr.CheckboxGroup(\n",
    "                        label=\"Precision\",\n",
    "                        choices=PRECISIONS,\n",
    "                        value=PRECISIONS,\n",
    "                        interactive=True,\n",
    "                        elem_id=\"filter-columns-precision\",\n",
    "                    )\n",
    "                    filter_columns_size = gr.CheckboxGroup(\n",
    "                        label=\"Model sizes (in billions of parameters)\",\n",
    "                        choices=list(NUMERIC_INTERVALS.keys()),\n",
    "                        value=list(NUMERIC_INTERVALS.keys()),\n",
    "                        interactive=True,\n",
    "                        elem_id=\"filter-columns-size\",\n",
    "                    )\n",
    "\n",
    "            leaderboard_table = gr.components.Dataframe(\n",
    "                value=df[ON_LOAD_COLS],\n",
    "                headers=ON_LOAD_COLS,\n",
    "                datatype=TYPES,\n",
    "                elem_id=\"leaderboard-table\",\n",
    "                interactive=False,\n",
    "                visible=True,\n",
    "                column_widths=[\"2%\", \"33%\"],\n",
    "            )\n",
    "\n",
    "            # Hidden copy of the full leaderboard (all COLS). Every update starts\n",
    "            # from it, so deleting search text can bring filtered-out rows back.\n",
    "            hidden_leaderboard_table_for_search = gr.components.Dataframe(\n",
    "                value=invisible_df[COLS],\n",
    "                headers=COLS,\n",
    "                datatype=TYPES,\n",
    "                visible=False,\n",
    "            )\n",
    "            # Shared inputs: every event recomputes the table from scratch.\n",
    "            table_inputs = [\n",
    "                hidden_leaderboard_table_for_search,\n",
    "                shown_columns,\n",
    "                filter_columns_type,\n",
    "                filter_columns_precision,\n",
    "                filter_columns_size,\n",
    "                search_bar,\n",
    "            ]\n",
    "            search_bar.submit(update_table, table_inputs, leaderboard_table)\n",
    "            for selector in [\n",
    "                shown_columns,\n",
    "                filter_columns_type,\n",
    "                filter_columns_precision,\n",
    "                filter_columns_size,\n",
    "            ]:\n",
    "                selector.change(\n",
    "                    update_table,\n",
    "                    table_inputs,\n",
    "                    leaderboard_table,\n",
    "                    queue=True,\n",
    "                )\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    demo.queue(default_concurrency_limit=40).launch()\n"
  ]}
 ], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}