{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": [
    "# Gradio Demo: titanic_survival\n",
    "\n",
    "Trains a random forest on the Titanic passenger list and serves it as a live Gradio app that predicts whether a given passenger would have survived."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q gradio scikit-learn numpy pandas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the dataset from the Gradio demo repo (safe to re-run).\n",
    "import os\n",
    "os.makedirs('files', exist_ok=True)\n",
    "!wget -q -O files/titanic.csv https://github.com/gradio-app/gradio/raw/main/demo/titanic_survival/files/titanic.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44380577570523278879349135829904343037",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "import gradio as gr\n",
    "\n",
    "# Seed shared by the train/test split and the forest so results are reproducible.\n",
    "SEED = 23\n",
    "# Model inputs, in the column order the classifier is trained on.\n",
    "FEATURE_COLUMNS = [\"Pclass\", \"Sex\", \"Age\", \"Fare\", \"Embarked\", \"Company\"]\n",
    "\n",
    "try:\n",
    "    current_dir = os.path.dirname(os.path.realpath(__file__))\n",
    "except NameError:\n",
    "    # `__file__` is undefined inside a Jupyter kernel; fall back to the working\n",
    "    # directory, which is where the download cell saved `files/titanic.csv`.\n",
    "    current_dir = os.getcwd()\n",
    "data = pd.read_csv(os.path.join(current_dir, \"files/titanic.csv\"))\n",
    "\n",
    "\n",
    "def encode_age(df):\n",
    "    \"\"\"Return a copy of `df` with `Age` replaced by its age-bracket index.\n",
    "\n",
    "    Missing ages become -0.5, i.e. they fall in the first bracket (-1, 0].\n",
    "    The input frame is not modified.\n",
    "    \"\"\"\n",
    "    bins = (-1, 0, 5, 12, 18, 25, 35, 60, 120)\n",
    "    return df.assign(Age=pd.cut(df[\"Age\"].fillna(-0.5), bins, labels=False))\n",
    "\n",
    "\n",
    "def encode_fare(df):\n",
    "    \"\"\"Return a copy of `df` with `Fare` replaced by its fare-bracket index.\n",
    "\n",
    "    Missing fares become -0.5, i.e. they fall in the first bracket (-1, 0].\n",
    "    The input frame is not modified.\n",
    "    \"\"\"\n",
    "    bins = (-1, 0, 8, 15, 31, 1000)\n",
    "    return df.assign(Fare=pd.cut(df[\"Fare\"].fillna(-0.5), bins, labels=False))\n",
    "\n",
    "\n",
    "def encode_df(df):\n",
    "    \"\"\"Return a numerically encoded copy of the raw Titanic frame.\n",
    "\n",
    "    The result holds `PassengerId`, the FEATURE_COLUMNS and the `Survived`\n",
    "    target; the input frame is left untouched.\n",
    "    \"\"\"\n",
    "    df = encode_fare(encode_age(df))\n",
    "    df = df.assign(\n",
    "        Sex=df[\"Sex\"].map({\"male\": 0, \"female\": 1}),\n",
    "        # Missing embarkation ports are encoded as 0.\n",
    "        Embarked=df[\"Embarked\"].map({\"S\": 1, \"C\": 2, \"Q\": 3}).fillna(0),\n",
    "        # 0 = alone, 1 = siblings/spouse, 2 = parents/children, 3 = both.\n",
    "        Company=(df[\"SibSp\"] > 0).astype(int) + 2 * (df[\"Parch\"] > 0).astype(int),\n",
    "    )\n",
    "    return df[[\"PassengerId\", *FEATURE_COLUMNS, \"Survived\"]]\n",
    "\n",
    "\n",
    "train = encode_df(data)\n",
    "\n",
    "X_all = train[FEATURE_COLUMNS]\n",
    "y_all = train[\"Survived\"]\n",
    "\n",
    "num_test = 0.20\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X_all, y_all, test_size=num_test, random_state=SEED\n",
    ")\n",
    "\n",
    "clf = RandomForestClassifier(random_state=SEED)\n",
    "clf.fit(X_train, y_train)\n",
    "print(f\"Held-out accuracy: {clf.score(X_test, y_test):.3f}\")\n",
    "\n",
    "\n",
    "def predict_survival(passenger_class, is_male, age, company, fare, embark_point):\n",
    "    \"\"\"Return survival probabilities for one passenger described by the UI inputs.\n",
    "\n",
    "    `passenger_class` and `embark_point` are the 0-based option indices produced\n",
    "    by the `type=\"index\"` Dropdown/Radio; `company` is the list of ticked boxes.\n",
    "    Returns a {label: probability} dict for the \"label\" output, or None while a\n",
    "    required choice is still unset.\n",
    "    \"\"\"\n",
    "    if passenger_class is None or embark_point is None:\n",
    "        return None\n",
    "    company = company or []\n",
    "    # Clamp into the encoders' bin ranges (out-of-range values would encode as NaN).\n",
    "    # A cleared input falls into the lowest bracket, like a missing value in training.\n",
    "    age = 0 if age is None else min(max(age, 0), 120)\n",
    "    fare = 0 if fare is None else min(max(fare, 0), 1000)\n",
    "    df = pd.DataFrame(\n",
    "        {\n",
    "            \"Pclass\": [passenger_class + 1],\n",
    "            \"Sex\": [0 if is_male else 1],\n",
    "            \"Age\": [age],\n",
    "            \"Fare\": [fare],\n",
    "            # Training encodes S/C/Q as 1/2/3.\n",
    "            \"Embarked\": [embark_point + 1],\n",
    "            # Same 0-3 scheme as encode_df: +1 siblings/spouse, +2 parents/children.\n",
    "            \"Company\": [\n",
    "                (1 if \"Sibling\" in company else 0) + (2 if \"Child\" in company else 0)\n",
    "            ],\n",
    "        }\n",
    "    )\n",
    "    df = encode_fare(encode_age(df))[FEATURE_COLUMNS]\n",
    "    perish, survive = clf.predict_proba(df)[0]\n",
    "    return {\"Perishes\": float(perish), \"Survives\": float(survive)}\n",
    "\n",
    "\n",
    "demo = gr.Interface(\n",
    "    predict_survival,\n",
    "    [\n",
    "        gr.Dropdown([\"first\", \"second\", \"third\"], type=\"index\", label=\"Passenger class\"),\n",
    "        gr.Checkbox(label=\"Male\"),\n",
    "        gr.Slider(0, 80, value=25, label=\"Age\"),\n",
    "        gr.CheckboxGroup([\"Sibling\", \"Child\"], label=\"Travelling with (select all)\"),\n",
    "        gr.Number(value=20, label=\"Fare\"),\n",
    "        gr.Radio([\"S\", \"C\", \"Q\"], type=\"index\", label=\"Port of embarkation\"),\n",
    "    ],\n",
    "    \"label\",\n",
    "    examples=[\n",
    "        [\"first\", True, 30, [], 50, \"S\"],\n",
    "        [\"second\", False, 40, [\"Sibling\", \"Child\"], 10, \"Q\"],\n",
    "        [\"third\", True, 30, [\"Child\"], 20, \"S\"],\n",
    "    ],\n",
    "    live=True,\n",
    ")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    demo.launch()\n"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}