freddyaboulton HF staff commited on
Commit
17fbcb8
1 Parent(s): 697d377

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. run.ipynb +1 -1
README.md CHANGED
@@ -5,7 +5,7 @@ emoji: 🔥
5
  colorFrom: indigo
6
  colorTo: indigo
7
  sdk: gradio
8
- sdk_version: 3.47.1
9
  app_file: run.py
10
  pinned: false
11
  hf_oauth: true
 
5
  colorFrom: indigo
6
  colorTo: indigo
7
  sdk: gradio
8
+ sdk_version: 3.48.0
9
  app_file: run.py
10
  pinned: false
11
  hf_oauth: true
run.ipynb CHANGED
@@ -1 +1 @@
1
- {"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: xgboost-income-prediction-with-explainability\n", "### This demo takes in 12 inputs from the user in dropdowns and sliders and predicts income. It also has a separate button for explaining the prediction.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio numpy==1.23.2 matplotlib shap xgboost==1.7.6 pandas datasets"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import random\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "import shap\n", "import xgboost as xgb\n", "from datasets import load_dataset\n", "\n", "\n", "dataset = load_dataset(\"scikit-learn/adult-census-income\")\n", "X_train = dataset[\"train\"].to_pandas()\n", "_ = X_train.pop(\"fnlwgt\")\n", "_ = X_train.pop(\"race\")\n", "y_train = X_train.pop(\"income\")\n", "y_train = (y_train == \">50K\").astype(int)\n", "categorical_columns = [\n", " \"workclass\",\n", " \"education\",\n", " \"marital.status\",\n", " \"occupation\",\n", " \"relationship\",\n", " \"sex\",\n", " \"native.country\",\n", "]\n", "X_train = X_train.astype({col: \"category\" for col in categorical_columns})\n", "data = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)\n", "model = xgb.train(params={\"objective\": \"binary:logistic\"}, dtrain=data)\n", "explainer = shap.TreeExplainer(model)\n", "\n", "def predict(*args):\n", " df = pd.DataFrame([args], columns=X_train.columns)\n", " df = df.astype({col: \"category\" for col in categorical_columns})\n", " pos_pred = model.predict(xgb.DMatrix(df, enable_categorical=True))\n", " return {\">50K\": float(pos_pred[0]), \"<=50K\": 1 - float(pos_pred[0])}\n", "\n", "\n", "def interpret(*args):\n", " df = pd.DataFrame([args], columns=X_train.columns)\n", " df = df.astype({col: \"category\" for col in categorical_columns})\n", " shap_values = explainer.shap_values(xgb.DMatrix(df, enable_categorical=True))\n", " scores_desc = list(zip(shap_values[0], X_train.columns))\n", " scores_desc = sorted(scores_desc)\n", " fig_m = plt.figure(tight_layout=True)\n", " plt.barh([s[1] for s in scores_desc], [s[0] for s in scores_desc])\n", " plt.title(\"Feature Shap Values\")\n", " plt.ylabel(\"Shap Value\")\n", " plt.xlabel(\"Feature\")\n", " plt.tight_layout()\n", " return fig_m\n", "\n", "\n", "unique_class = sorted(X_train[\"workclass\"].unique())\n", "unique_education = sorted(X_train[\"education\"].unique())\n", "unique_marital_status = sorted(X_train[\"marital.status\"].unique())\n", "unique_relationship = sorted(X_train[\"relationship\"].unique())\n", "unique_occupation = sorted(X_train[\"occupation\"].unique())\n", "unique_sex = sorted(X_train[\"sex\"].unique())\n", "unique_country = sorted(X_train[\"native.country\"].unique())\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\"\"\"\n", " **Income Classification with XGBoost \ud83d\udcb0**: This demo uses an XGBoost classifier predicts income based on demographic factors, along with Shapley value-based *explanations*. The [source code for this Gradio demo is here](https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability/blob/main/app.py).\n", " \"\"\")\n", " with gr.Row():\n", " with gr.Column():\n", " age = gr.Slider(label=\"Age\", minimum=17, maximum=90, step=1, randomize=True)\n", " work_class = gr.Dropdown(\n", " label=\"Workclass\",\n", " choices=unique_class,\n", " value=lambda: random.choice(unique_class),\n", " )\n", " education = gr.Dropdown(\n", " label=\"Education Level\",\n", " choices=unique_education,\n", " value=lambda: random.choice(unique_education),\n", " )\n", " years = gr.Slider(\n", " label=\"Years of schooling\",\n", " minimum=1,\n", " maximum=16,\n", " step=1,\n", " randomize=True,\n", " )\n", " marital_status = gr.Dropdown(\n", " label=\"Marital Status\",\n", " choices=unique_marital_status,\n", " value=lambda: random.choice(unique_marital_status),\n", " )\n", " occupation = gr.Dropdown(\n", " label=\"Occupation\",\n", " choices=unique_occupation,\n", " value=lambda: random.choice(unique_occupation),\n", " )\n", " relationship = gr.Dropdown(\n", " label=\"Relationship Status\",\n", " choices=unique_relationship,\n", " value=lambda: random.choice(unique_relationship),\n", " )\n", " sex = gr.Dropdown(\n", " label=\"Sex\", choices=unique_sex, value=lambda: random.choice(unique_sex)\n", " )\n", " capital_gain = gr.Slider(\n", " label=\"Capital Gain\",\n", " minimum=0,\n", " maximum=100000,\n", " step=500,\n", " randomize=True,\n", " )\n", " capital_loss = gr.Slider(\n", " label=\"Capital Loss\", minimum=0, maximum=10000, step=500, randomize=True\n", " )\n", " hours_per_week = gr.Slider(\n", " label=\"Hours Per Week Worked\", minimum=1, maximum=99, step=1\n", " )\n", " country = gr.Dropdown(\n", " label=\"Native Country\",\n", " choices=unique_country,\n", " value=lambda: random.choice(unique_country),\n", " )\n", " with gr.Column():\n", " label = gr.Label()\n", " plot = gr.Plot()\n", " with gr.Row():\n", " predict_btn = gr.Button(value=\"Predict\")\n", " interpret_btn = gr.Button(value=\"Explain\")\n", " predict_btn.click(\n", " predict,\n", " inputs=[\n", " age,\n", " work_class,\n", " education,\n", " years,\n", " marital_status,\n", " occupation,\n", " relationship,\n", " sex,\n", " capital_gain,\n", " capital_loss,\n", " hours_per_week,\n", " country,\n", " ],\n", " outputs=[label],\n", " )\n", " interpret_btn.click(\n", " interpret,\n", " inputs=[\n", " age,\n", " work_class,\n", " education,\n", " years,\n", " marital_status,\n", " occupation,\n", " relationship,\n", " sex,\n", " capital_gain,\n", " capital_loss,\n", " hours_per_week,\n", " country,\n", " ],\n", " outputs=[plot],\n", " )\n", "\n", "demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
 
1
+ {"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: xgboost-income-prediction-with-explainability\n", "### This demo takes in 12 inputs from the user in dropdowns and sliders and predicts income. It also has a separate button for explaining the prediction.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio numpy==1.23.2 matplotlib shap xgboost==1.7.6 pandas datasets"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import random\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "import shap\n", "import xgboost as xgb\n", "from datasets import load_dataset\n", "\n", "\n", "dataset = load_dataset(\"scikit-learn/adult-census-income\")\n", "X_train = dataset[\"train\"].to_pandas()\n", "_ = X_train.pop(\"fnlwgt\")\n", "_ = X_train.pop(\"race\")\n", "y_train = X_train.pop(\"income\")\n", "y_train = (y_train == \">50K\").astype(int)\n", "categorical_columns = [\n", " \"workclass\",\n", " \"education\",\n", " \"marital.status\",\n", " \"occupation\",\n", " \"relationship\",\n", " \"sex\",\n", " \"native.country\",\n", "]\n", "X_train = X_train.astype({col: \"category\" for col in categorical_columns})\n", "data = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)\n", "model = xgb.train(params={\"objective\": \"binary:logistic\"}, dtrain=data)\n", "explainer = shap.TreeExplainer(model)\n", "\n", "def predict(*args):\n", " df = pd.DataFrame([args], columns=X_train.columns)\n", " df = df.astype({col: \"category\" for col in categorical_columns})\n", " pos_pred = model.predict(xgb.DMatrix(df, enable_categorical=True))\n", " return {\">50K\": float(pos_pred[0]), \"<=50K\": 1 - float(pos_pred[0])}\n", "\n", "\n", "def interpret(*args):\n", " df = pd.DataFrame([args], columns=X_train.columns)\n", " df = df.astype({col: \"category\" for col in categorical_columns})\n", " shap_values = explainer.shap_values(xgb.DMatrix(df, enable_categorical=True))\n", " scores_desc = list(zip(shap_values[0], X_train.columns))\n", " scores_desc = sorted(scores_desc)\n", " fig_m = plt.figure(tight_layout=True)\n", " plt.barh([s[1] for s in scores_desc], [s[0] for s in scores_desc])\n", " plt.title(\"Feature Shap Values\")\n", " plt.ylabel(\"Shap Value\")\n", " plt.xlabel(\"Feature\")\n", " plt.tight_layout()\n", " return fig_m\n", "\n", "\n", "unique_class = sorted(X_train[\"workclass\"].unique())\n", "unique_education = sorted(X_train[\"education\"].unique())\n", "unique_marital_status = sorted(X_train[\"marital.status\"].unique())\n", "unique_relationship = sorted(X_train[\"relationship\"].unique())\n", "unique_occupation = sorted(X_train[\"occupation\"].unique())\n", "unique_sex = sorted(X_train[\"sex\"].unique())\n", "unique_country = sorted(X_train[\"native.country\"].unique())\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\"\"\"\n", " **Income Classification with XGBoost \ud83d\udcb0**: This demo uses an XGBoost classifier predicts income based on demographic factors, along with Shapley value-based *explanations*. The [source code for this Gradio demo is here](https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability/blob/main/app.py).\n", " \"\"\")\n", " with gr.Row():\n", " with gr.Column():\n", " age = gr.Slider(label=\"Age\", minimum=17, maximum=90, step=1, randomize=True)\n", " work_class = gr.Dropdown(\n", " label=\"Workclass\",\n", " choices=unique_class,\n", " value=lambda: random.choice(unique_class),\n", " )\n", " education = gr.Dropdown(\n", " label=\"Education Level\",\n", " choices=unique_education,\n", " value=lambda: random.choice(unique_education),\n", " )\n", " years = gr.Slider(\n", " label=\"Years of schooling\",\n", " minimum=1,\n", " maximum=16,\n", " step=1,\n", " randomize=True,\n", " )\n", " marital_status = gr.Dropdown(\n", " label=\"Marital Status\",\n", " choices=unique_marital_status,\n", " value=lambda: random.choice(unique_marital_status),\n", " )\n", " occupation = gr.Dropdown(\n", " label=\"Occupation\",\n", " choices=unique_occupation,\n", " value=lambda: random.choice(unique_occupation),\n", " )\n", " relationship = gr.Dropdown(\n", " label=\"Relationship Status\",\n", " choices=unique_relationship,\n", " value=lambda: random.choice(unique_relationship),\n", " )\n", " sex = gr.Dropdown(\n", " label=\"Sex\", choices=unique_sex, value=lambda: random.choice(unique_sex)\n", " )\n", " capital_gain = gr.Slider(\n", " label=\"Capital Gain\",\n", " minimum=0,\n", " maximum=100000,\n", " step=500,\n", " randomize=True,\n", " )\n", " capital_loss = gr.Slider(\n", " label=\"Capital Loss\", minimum=0, maximum=10000, step=500, randomize=True\n", " )\n", " hours_per_week = gr.Slider(\n", " label=\"Hours Per Week Worked\", minimum=1, maximum=99, step=1\n", " )\n", " country = gr.Dropdown(\n", " label=\"Native Country\",\n", " choices=unique_country,\n", " value=lambda: random.choice(unique_country),\n", " )\n", " with gr.Column():\n", " label = gr.Label()\n", " plot = gr.Plot()\n", " with gr.Row():\n", " predict_btn = gr.Button(value=\"Predict\")\n", " interpret_btn = gr.Button(value=\"Explain\")\n", " predict_btn.click(\n", " predict,\n", " inputs=[\n", " age,\n", " work_class,\n", " education,\n", " years,\n", " marital_status,\n", " occupation,\n", " relationship,\n", " sex,\n", " capital_gain,\n", " capital_loss,\n", " hours_per_week,\n", " country,\n", " ],\n", " outputs=[label],\n", " )\n", " interpret_btn.click(\n", " interpret,\n", " inputs=[\n", " age,\n", " work_class,\n", " education,\n", " years,\n", " marital_status,\n", " occupation,\n", " relationship,\n", " sex,\n", " capital_gain,\n", " capital_loss,\n", " hours_per_week,\n", " country,\n", " ],\n", " outputs=[plot],\n", " )\n", "\n", "demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}