File size: 4,517 Bytes
71f951f
1
{"cells": [
{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: titanic_survival"]},
{"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["%pip install -q gradio scikit-learn numpy pandas"]},
{"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": [
"# Downloading files from the demo repo\n",
"import os\n",
"# exist_ok=True keeps this cell re-runnable (os.mkdir fails if the folder exists).\n",
"os.makedirs('files', exist_ok=True)\n",
"!wget -q -O files/titanic.csv https://github.com/gradio-app/gradio/raw/main/demo/titanic_survival/files/titanic.csv"
]},
{"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": [
"import os\n",
"\n",
"import pandas as pd\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"import gradio as gr\n",
"\n",
"# `__file__` is defined when this code runs as a script, but not inside a\n",
"# Jupyter kernel; fall back to the working directory, where the previous\n",
"# cell downloaded files/titanic.csv.\n",
"try:\n",
"    current_dir = os.path.dirname(os.path.realpath(__file__))\n",
"except NameError:\n",
"    current_dir = os.getcwd()\n",
"data = pd.read_csv(os.path.join(current_dir, \"files/titanic.csv\"))\n",
"\n",
"\n",
"def encode_age(df):\n",
"    \"\"\"Replace the Age column of `df` with an age-bracket index; mutates and returns `df`.\n",
"\n",
"    Missing ages are filled with -0.5 so they fall into the first bracket (-1, 0].\n",
"    \"\"\"\n",
"    df.Age = df.Age.fillna(-0.5)\n",
"    bins = (-1, 0, 5, 12, 18, 25, 35, 60, 120)\n",
"    df.Age = pd.cut(df.Age, bins, labels=False)\n",
"    return df\n",
"\n",
"\n",
"def encode_fare(df):\n",
"    \"\"\"Replace the Fare column of `df` with a fare-bracket index; mutates and returns `df`.\n",
"\n",
"    Missing fares are filled with -0.5 so they fall into the first bracket (-1, 0].\n",
"    Fares outside (-1, 1000] are not in any bracket and encode as NaN.\n",
"    \"\"\"\n",
"    df.Fare = df.Fare.fillna(-0.5)\n",
"    bins = (-1, 0, 8, 15, 31, 1000)\n",
"    df.Fare = pd.cut(df.Fare, bins, labels=False)\n",
"    return df\n",
"\n",
"\n",
"def encode_df(df):\n",
"    \"\"\"Build the numeric training table from the raw Titanic frame.\n",
"\n",
"    Works on a copy, so the caller's frame is left unchanged. Returns the columns\n",
"    PassengerId, Pclass, Sex, Age, Fare, Embarked, Company, Survived.\n",
"    \"\"\"\n",
"    # Copy first: encode_age/encode_fare assign columns in place.\n",
"    df = encode_age(df.copy())\n",
"    df = encode_fare(df)\n",
"    # .map avoids pandas' deprecated downcasting in DataFrame.replace.\n",
"    df[\"Sex\"] = df[\"Sex\"].map({\"male\": 0, \"female\": 1})\n",
"    # Missing embarkation ports are encoded as 0.\n",
"    df[\"Embarked\"] = df[\"Embarked\"].map({\"S\": 1, \"C\": 2, \"Q\": 3}).fillna(0)\n",
"    # Company: 0 = neither, 1 = SibSp only, 2 = Parch only, 3 = both.\n",
"    df[\"Company\"] = 0\n",
"    df.loc[df[\"SibSp\"] > 0, \"Company\"] = 1\n",
"    df.loc[df[\"Parch\"] > 0, \"Company\"] = 2\n",
"    df.loc[(df[\"SibSp\"] > 0) & (df[\"Parch\"] > 0), \"Company\"] = 3\n",
"    return df[\n",
"        [\n",
"            \"PassengerId\",\n",
"            \"Pclass\",\n",
"            \"Sex\",\n",
"            \"Age\",\n",
"            \"Fare\",\n",
"            \"Embarked\",\n",
"            \"Company\",\n",
"            \"Survived\",\n",
"        ]\n",
"    ]\n",
"\n",
"\n",
"train = encode_df(data)\n",
"\n",
"X_all = train.drop([\"Survived\", \"PassengerId\"], axis=1)\n",
"y_all = train[\"Survived\"]\n",
"\n",
"num_test = 0.20\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X_all, y_all, test_size=num_test, random_state=23\n",
")\n",
"\n",
"# Fixed seed so the demo's predictions are reproducible across runs.\n",
"clf = RandomForestClassifier(random_state=23)\n",
"clf.fit(X_train, y_train)\n",
"predictions = clf.predict(X_test)\n",
"\n",
"\n",
"def predict_survival(passenger_class, is_male, age, company, fare, embark_point):\n",
"    \"\"\"Return Perishes/Survives probabilities for one passenger, or None if inputs are incomplete.\n",
"\n",
"    passenger_class and embark_point are 0-based indices (the Dropdown and Radio\n",
"    use type=index); company is the list of ticked CheckboxGroup options.\n",
"    \"\"\"\n",
"    if passenger_class is None or embark_point is None:\n",
"        return None\n",
"    company = company or []\n",
"    if fare is not None:\n",
"        # Keep the fare inside encode_fare's brackets (-1, 1000]; otherwise it encodes as NaN.\n",
"        fare = min(max(fare, 0), 999)\n",
"    df = pd.DataFrame.from_dict(\n",
"        {\n",
"            \"Pclass\": [passenger_class + 1],\n",
"            \"Sex\": [0 if is_male else 1],\n",
"            \"Age\": [age],\n",
"            \"Fare\": [fare],\n",
"            \"Embarked\": [embark_point + 1],\n",
"            \"Company\": [\n",
"                (1 if \"Sibling\" in company else 0) + (2 if \"Child\" in company else 0)\n",
"            ],\n",
"        }\n",
"    )\n",
"    df = encode_age(df)\n",
"    df = encode_fare(df)\n",
"    pred = clf.predict_proba(df)[0]\n",
"    return {\"Perishes\": float(pred[0]), \"Survives\": float(pred[1])}\n",
"\n",
"\n",
"demo = gr.Interface(\n",
"    predict_survival,\n",
"    [\n",
"        gr.Dropdown([\"first\", \"second\", \"third\"], type=\"index\"),\n",
"        \"checkbox\",\n",
"        gr.Slider(0, 80, value=25),\n",
"        gr.CheckboxGroup([\"Sibling\", \"Child\"], label=\"Travelling with (select all)\"),\n",
"        gr.Number(value=20),\n",
"        gr.Radio([\"S\", \"C\", \"Q\"], type=\"index\"),\n",
"    ],\n",
"    \"label\",\n",
"    examples=[\n",
"        [\"first\", True, 30, [], 50, \"S\"],\n",
"        [\"second\", False, 40, [\"Sibling\", \"Child\"], 10, \"Q\"],\n",
"        [\"third\", True, 30, [\"Child\"], 20, \"S\"],\n",
"    ],\n",
"    live=True,\n",
")\n",
"\n",
"if __name__ == \"__main__\":\n",
"    demo.launch()\n"
]}
], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}