{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "collapsed_sections": [ "kBcnQ90G8_o-" ], "toc_visible": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "source": [ "# Import libaries\n" ], "metadata": { "id": "38OGDbSI84N-" } }, { "cell_type": "code", "source": [ "!pip install skops" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "collapsed": true, "id": "XfXjrEjx6dOM", "outputId": "e93b1a5f-c269-4c3c-e6db-c6ece761c7c4" }, "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting skops\n", " Downloading skops-0.10.0-py3-none-any.whl.metadata (5.8 kB)\n", "Requirement already satisfied: scikit-learn>=0.24 in /usr/local/lib/python3.10/dist-packages (from skops) (1.5.2)\n", "Requirement already satisfied: huggingface-hub>=0.17.0 in /usr/local/lib/python3.10/dist-packages (from skops) (0.26.2)\n", "Requirement already satisfied: tabulate>=0.8.8 in /usr/local/lib/python3.10/dist-packages (from skops) (0.9.0)\n", "Requirement already satisfied: packaging>=17.0 in /usr/local/lib/python3.10/dist-packages (from skops) (24.2)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->skops) (3.16.1)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->skops) (2024.10.0)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->skops) (6.0.2)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->skops) (2.32.3)\n", "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->skops) (4.66.6)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->skops) (4.12.2)\n", "Requirement already satisfied: numpy>=1.19.5 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.24->skops) (1.26.4)\n", "Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.24->skops) (1.13.1)\n", "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.24->skops) (1.4.2)\n", "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.24->skops) (3.5.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.17.0->skops) (3.4.0)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.17.0->skops) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.17.0->skops) (2.2.3)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.17.0->skops) (2024.8.30)\n", "Downloading skops-0.10.0-py3-none-any.whl (121 kB)\n", "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/121.9 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m\u001b[90m━━━\u001b[0m \u001b[32m112.6/121.9 kB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.9/121.9 kB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hInstalling collected packages: skops\n", "Successfully installed skops-0.10.0\n" ] } ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "zrmBieIvzkwr" }, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "import pandas as pd\n", "import numpy as np\n", "\n", "import sklearn\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.metrics import classification_report, accuracy_score, confusion_matrix, ConfusionMatrixDisplay\n", "\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "from pathlib import Path\n", "from tempfile import mkdtemp, mkstemp\n", "\n", "import skops.io as sio\n", "from skops import card, hub_utils" ] }, { "cell_type": "markdown", "source": [ "# Load datasets" ], "metadata": { "id": "kBcnQ90G8_o-" } }, { "cell_type": "code", "source": [ "mpl_df = pd.read_csv(\"/content/mpl_id_s14.csv\")\n", "attributes_df = pd.read_csv(\"/content/mlbb_heroes_attribute.csv\")" ], "metadata": { "id": "c4JKbn5x7_JW" }, "execution_count": 3, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Drop unused columns" ], "metadata": { "id": "wTu-lqvP9EaI" } }, { "cell_type": "code", "source": [ "# drop unused columns\n", "mpl_df.drop(columns=['no','week','date','match','game','match_length','blue_team','red_team'], inplace=True)\n", "attributes_df.drop(columns=['id','main_role','secondary_role','main_damage_type'], inplace=True)" ], "metadata": { "id": "1t6Viu528I1t" }, "execution_count": 4, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Def functions to feature enginering" ], "metadata": { "id": "f_hLrJ1O9H_K" } }, { "cell_type": "code", "source": [ "# Function to add sum attributes for each side\n", "def calculate_side_features(mpl_df, attribute_df):\n", " # Create a hero mapping dictionary for fast lookups\n", " attribute_dict = attribute_df.set_index('hero')[['durability', 'offense', 'control_effects', 'difficulty', 'early', 'mid', 'late']].to_dict('index')\n", "\n", " # Internal function to calculate the sum of attributes from one side (blue or red)\n", " def sum_attributes(row, side):\n", " positions = ['explaner', 'jungler', 'midlaner', 'goldlaner', 'roamer']\n", " features = ['durability', 'offense', 'control_effects', 'difficulty', 'early', 'mid', 'late']\n", " side_sums = {f\"{side}_{feature}\": 0 for feature in features}\n", "\n", " for pos in positions:\n", " hero = row[f\"{side}_{pos}\"]\n", " if hero in attribute_dict:\n", " for feature in features:\n", " side_sums[f\"{side}_{feature}\"] += attribute_dict[hero][feature]\n", "\n", " return pd.Series(side_sums)\n", "\n", " # Apply the above function to each row in mpl_df\n", " blue_features = mpl_df.apply(lambda row: sum_attributes(row, 'blue'), axis=1)\n", " red_features = mpl_df.apply(lambda row: sum_attributes(row, 'red'), axis=1)\n", "\n", " # Merge the results into mpl_df\n", " mpl_df = pd.concat([mpl_df, blue_features, red_features], axis=1)\n", "\n", " # Calculate the total power spike for each side\n", " mpl_df['blue_total_power_spike'] = mpl_df[['blue_early', 'blue_mid', 'blue_late']].sum(axis=1)\n", " mpl_df['red_total_power_spike'] = mpl_df[['red_early', 'red_mid', 'red_late']].sum(axis=1)\n", "\n", " return mpl_df" ], "metadata": { "id": "EUVghJRP8LvX" }, "execution_count": 5, "outputs": [] }, { "cell_type": "code", "source": [ "# Functions for feature engineering on DataFrame\n", "def perform_feature_engineering(mpl_df):\n", " # Feature engineering based on attribute differences\n", " mpl_df['durability_diff'] = mpl_df['blue_durability'] - mpl_df['red_durability']\n", " mpl_df['offense_diff'] = mpl_df['blue_offense'] - mpl_df['red_offense']\n", " mpl_df['control_effects_diff'] = mpl_df['blue_control_effects'] - mpl_df['red_control_effects']\n", " mpl_df['difficulty_diff'] = mpl_df['blue_difficulty'] - mpl_df['red_difficulty']\n", " mpl_df['power_spike_diff'] = mpl_df['blue_total_power_spike'] - mpl_df['red_total_power_spike']\n", "\n", " # Average value for role per team\n", " mpl_df['blue_avg_durability'] = mpl_df['blue_durability'] / 5\n", " mpl_df['red_avg_durability'] = mpl_df['red_durability'] / 5\n", " mpl_df['blue_avg_offense'] = mpl_df['blue_offense'] / 5\n", " mpl_df['red_avg_offense'] = mpl_df['red_offense'] / 5\n", "\n", " # Aggressiveness score\n", " mpl_df['blue_aggressiveness_score'] = mpl_df['blue_offense'] / mpl_df['blue_durability']\n", " mpl_df['red_aggressiveness_score'] = mpl_df['red_offense'] / mpl_df['red_durability']\n", "\n", " # Early-mid and mid-late comparison for the blue team\n", " mpl_df['blue_early_mid_ratio'] = np.where(mpl_df['blue_mid'] == 0, 1, mpl_df['blue_early'] / mpl_df['blue_mid'])\n", " mpl_df['blue_mid_late_ratio'] = np.where(mpl_df['blue_late'] == 0, 1, mpl_df['blue_mid'] / mpl_df['blue_late'])\n", "\n", " # Early-mid and mid-late comparison for the red team\n", " mpl_df['red_early_mid_ratio'] = np.where(mpl_df['red_mid'] == 0, 1, mpl_df['red_early'] / mpl_df['red_mid'])\n", " mpl_df['red_mid_late_ratio'] = np.where(mpl_df['red_late'] == 0, 1, mpl_df['red_mid'] / mpl_df['red_late'])\n", "\n", " # Drop unused position columns\n", " mpl_df.drop(columns=[\n", " 'blue_explaner', 'blue_jungler', 'blue_midlaner', 'blue_goldlaner', 'blue_roamer',\n", " 'red_explaner', 'red_jungler', 'red_midlaner', 'red_goldlaner', 'red_roamer'\n", " ], inplace=True)\n", "\n", " # Drop unnecessary columns for the model\n", " mpl_df = mpl_df.drop(columns=[\n", " 'blue_durability', 'blue_offense', 'blue_control_effects', 'blue_difficulty',\n", " 'red_durability', 'red_offense', 'red_control_effects', 'red_difficulty',\n", " 'blue_total_power_spike',\n", " 'blue_early', 'blue_mid', 'blue_late',\n", " 'red_total_power_spike',\n", " 'red_early', 'red_mid', 'red_late'\n", " ])\n", "\n", " return mpl_df" ], "metadata": { "id": "3IdkY9Qh8P32" }, "execution_count": 6, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Feature enginering" ], "metadata": { "id": "R2zkJH4n9O9r" } }, { "cell_type": "code", "source": [ "# calculate mpl_df and attributes_df\n", "mpl_attr_df = calculate_side_features(mpl_df, attributes_df)\n", "mpl_attr_df.head()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 256 }, "id": "ROqemk878N8_", "outputId": "db58df78-28c3-40cd-fbb9-c66aaeb321e5" }, "execution_count": 7, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " blue_explaner blue_jungler blue_midlaner blue_goldlaner blue_roamer \\\n", "0 Edith Julian Luo Yi Moskov Minotaur \n", "1 Thamuz Nolan Valentina Harith Minotaur \n", "2 Hylos Julian Valentina Claude Edith \n", "3 X.Borg Julian Valentina Moskov Arlott \n", "4 Paquito Julian Valentina Moskov Hylos \n", "\n", " red_explaner red_jungler red_midlaner red_goldlaner red_roamer ... \\\n", "0 Arlott Roger Valentina Harith Chou ... \n", "1 Arlott Joy Vexana Roger Edith ... \n", "2 Ruby Nolan Vexana Moskov Khufra ... \n", "3 Hylos Roger Zhask Claude Guinevere ... \n", "4 X.Borg Roger Novaria Harith Guinevere ... \n", "\n", " blue_late red_durability red_offense red_control_effects red_difficulty \\\n", "0 3 26 27 23 30 \n", "1 1 24 29 22 23 \n", "2 2 25 25 32 21 \n", "3 2 24 27 21 22 \n", "4 1 25 26 13 24 \n", "\n", " red_early red_mid red_late blue_total_power_spike red_total_power_spike \n", "0 3 3 2 8 8 \n", "1 2 4 2 8 8 \n", "2 3 3 1 9 7 \n", "3 2 4 3 8 9 \n", "4 1 5 4 7 10 \n", "\n", "[5 rows x 27 columns]" ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
blue_explanerblue_junglerblue_midlanerblue_goldlanerblue_roamerred_explanerred_junglerred_midlanerred_goldlanerred_roamer...blue_latered_durabilityred_offensered_control_effectsred_difficultyred_earlyred_midred_lateblue_total_power_spikered_total_power_spike
0EdithJulianLuo YiMoskovMinotaurArlottRogerValentinaHarithChou...32627233033288
1ThamuzNolanValentinaHarithMinotaurArlottJoyVexanaRogerEdith...12429222324288
2HylosJulianValentinaClaudeEdithRubyNolanVexanaMoskovKhufra...22525322133197
3X.BorgJulianValentinaMoskovArlottHylosRogerZhaskClaudeGuinevere...22427212224389
4PaquitoJulianValentinaMoskovHylosX.BorgRogerNovariaHarithGuinevere...125261324154710
\n", "

5 rows × 27 columns

\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "\n", "
\n", "
\n" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "dataframe", "variable_name": "mpl_attr_df" } }, "metadata": {}, "execution_count": 7 } ] }, { "cell_type": "code", "source": [ "# feature enginering\n", "mpl_df_transformed = perform_feature_engineering(mpl_attr_df)\n", "mpl_df_transformed = mpl_df_transformed.round(3)\n", "mpl_df_transformed.head()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 226 }, "id": "IvGErtHo8Sk3", "outputId": "f5c666ff-9150-46d8-eec6-15be90cecf5b" }, "execution_count": 8, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " result durability_diff offense_diff control_effects_diff \\\n", "0 RED 1 -1 12 \n", "1 BLUE 5 -4 -1 \n", "2 BLUE 6 1 2 \n", "3 RED 3 3 6 \n", "4 BLUE 4 2 20 \n", "\n", " difficulty_diff power_spike_diff blue_avg_durability red_avg_durability \\\n", "0 -11 0 5.4 5.2 \n", "1 4 0 5.8 4.8 \n", "2 1 2 6.2 5.0 \n", "3 2 -1 5.4 4.8 \n", "4 2 -3 5.8 5.0 \n", "\n", " blue_avg_offense red_avg_offense blue_aggressiveness_score \\\n", "0 5.2 5.4 0.963 \n", "1 5.0 5.8 0.862 \n", "2 5.2 5.0 0.839 \n", "3 6.0 5.4 1.111 \n", "4 5.6 5.2 0.966 \n", "\n", " red_aggressiveness_score blue_early_mid_ratio blue_mid_late_ratio \\\n", "0 1.038 0.667 1.0 \n", "1 1.208 0.750 4.0 \n", "2 1.000 0.750 2.0 \n", "3 1.125 1.000 1.5 \n", "4 1.040 1.000 3.0 \n", "\n", " red_early_mid_ratio red_mid_late_ratio \n", "0 1.0 1.500 \n", "1 0.5 2.000 \n", "2 1.0 3.000 \n", "3 0.5 1.333 \n", "4 0.2 1.250 " ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
resultdurability_diffoffense_diffcontrol_effects_diffdifficulty_diffpower_spike_diffblue_avg_durabilityred_avg_durabilityblue_avg_offensered_avg_offenseblue_aggressiveness_scorered_aggressiveness_scoreblue_early_mid_ratioblue_mid_late_ratiored_early_mid_ratiored_mid_late_ratio
0RED1-112-1105.45.25.25.40.9631.0380.6671.01.01.500
1BLUE5-4-1405.84.85.05.80.8621.2080.7504.00.52.000
2BLUE612126.25.05.25.00.8391.0000.7502.01.03.000
3RED3362-15.44.86.05.41.1111.1251.0001.50.51.333
4BLUE42202-35.85.05.65.20.9661.0401.0003.00.21.250
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "\n", "
\n", "
\n" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "dataframe", "variable_name": "mpl_df_transformed", "summary": "{\n \"name\": \"mpl_df_transformed\",\n \"rows\": 212,\n \"fields\": [\n {\n \"column\": \"result\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"BLUE\",\n \"RED\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"durability_diff\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 4,\n \"min\": -14,\n \"max\": 11,\n \"num_unique_values\": 23,\n \"samples\": [\n -3,\n 7\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"offense_diff\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 5,\n \"min\": -12,\n \"max\": 19,\n \"num_unique_values\": 26,\n \"samples\": [\n 5,\n -8\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"control_effects_diff\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 7,\n \"min\": -22,\n \"max\": 20,\n \"num_unique_values\": 37,\n \"samples\": [\n -14,\n -2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"difficulty_diff\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 7,\n \"min\": -16,\n \"max\": 18,\n \"num_unique_values\": 35,\n \"samples\": [\n 16,\n 17\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"power_spike_diff\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": -4,\n \"max\": 4,\n \"num_unique_values\": 9,\n \"samples\": [\n 3,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"blue_avg_durability\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6214503060388611,\n \"min\": 3.6,\n \"max\": 7.4,\n \"num_unique_values\": 19,\n \"samples\": [\n 5.4,\n 4.6\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"red_avg_durability\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6430055511285113,\n \"min\": 3.8,\n \"max\": 7.4,\n \"num_unique_values\": 18,\n \"samples\": [\n 5.2,\n 4.8\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"blue_avg_offense\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6480312945415349,\n \"min\": 3.2,\n \"max\": 7.2,\n \"num_unique_values\": 18,\n \"samples\": [\n 5.2,\n 5.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"red_avg_offense\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6736092644226174,\n \"min\": 2.8,\n \"max\": 7.2,\n \"num_unique_values\": 19,\n \"samples\": [\n 5.4,\n 6.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"blue_aggressiveness_score\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.21542079891962987,\n \"min\": 0.571,\n \"max\": 2.0,\n \"num_unique_values\": 98,\n \"samples\": [\n 1.259,\n 1.318\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"red_aggressiveness_score\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.2080560009696417,\n \"min\": 0.452,\n \"max\": 1.565,\n \"num_unique_values\": 109,\n \"samples\": [\n 0.571,\n 0.774\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"blue_early_mid_ratio\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.39883052760267895,\n \"min\": 0.2,\n \"max\": 2.0,\n \"num_unique_values\": 14,\n \"samples\": [\n 0.333,\n 0.4\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"blue_mid_late_ratio\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.9611989055114264,\n \"min\": 0.667,\n \"max\": 5.0,\n \"num_unique_values\": 12,\n \"samples\": [\n 1.25,\n 2.5\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"red_early_mid_ratio\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.3949263231878436,\n \"min\": 0.2,\n \"max\": 2.0,\n \"num_unique_values\": 13,\n \"samples\": [\n 1.25,\n 1.5\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"red_mid_late_ratio\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1.0457086449633677,\n \"min\": 0.5,\n \"max\": 5.0,\n \"num_unique_values\": 13,\n \"samples\": [\n 5.0,\n 0.667\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" } }, "metadata": {}, "execution_count": 8 } ] }, { "cell_type": "markdown", "source": [ "# Split dataset" ], "metadata": { "id": "2cgtljuR9SFN" } }, { "cell_type": "code", "source": [ "X = mpl_df_transformed.drop(columns=['result']) # All columns except 'result'\n", "y = mpl_df_transformed['result'] # target column" ], "metadata": { "id": "ChhFq-XCINNy" }, "execution_count": 9, "outputs": [] }, { "cell_type": "code", "source": [ "# Split dataset 75/25\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42, stratify=y)" ], "metadata": { "id": "qWReEl_f9Wm9" }, "execution_count": 10, "outputs": [] }, { "cell_type": "code", "source": [ "# Check the size of each dataset\n", "print(f\"Training set size: {X_train.shape[0]}\")\n", "print(f\"Test set size: {X_test.shape[0]}\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "2PsEW4l7--cQ", "outputId": "d60c7d5b-e89a-4b9f-fbc6-8c8f11a9e9d1" }, "execution_count": 11, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Training set size: 159\n", "Test set size: 53\n" ] } ] }, { "cell_type": "code", "source": [ "y_train.value_counts()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 178 }, "id": "F7NBmeyTCTfR", "outputId": "3b27b220-1615-46bd-e726-9239284bd40d" }, "execution_count": 12, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "result\n", "RED 80\n", "BLUE 79\n", "Name: count, dtype: int64" ], "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
count
result
RED80
BLUE79
\n", "

" ] }, "metadata": {}, "execution_count": 12 } ] }, { "cell_type": "markdown", "source": [ "# Train model" ], "metadata": { "id": "MTr1h0Ts9WH4" } }, { "cell_type": "code", "source": [ "# Train model with Random Forest\n", "model = RandomForestClassifier(\n", " random_state=91,\n", " n_estimators=50,\n", " max_depth=2,\n", " min_samples_split=2,\n", " min_samples_leaf=8,\n", ")\n", "model.fit(X_train, y_train)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 98 }, "id": "hsrMfXt59gWP", "outputId": "32d56ac8-797b-44b3-a090-c8eebf143696" }, "execution_count": 13, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "RandomForestClassifier(max_depth=2, min_samples_leaf=8, n_estimators=50,\n", " random_state=91)" ], "text/html": [ "
RandomForestClassifier(max_depth=2, min_samples_leaf=8, n_estimators=50,\n",
              "                       random_state=91)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ] }, "metadata": {}, "execution_count": 13 } ] }, { "cell_type": "markdown", "source": [ "# Inference" ], "metadata": { "id": "Y6IkbI5C9wUv" } }, { "cell_type": "code", "source": [ "# Make predictions on test data\n", "y_pred = model.predict(X_test)" ], "metadata": { "id": "7zPF9MGyCIBf" }, "execution_count": 14, "outputs": [] }, { "cell_type": "code", "source": [ "print(y_pred)\n", "print(y_test.values)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "DfgEB06c9lyW", "outputId": "a0320a3f-3fc9-4a41-a681-807b70c9f0be" }, "execution_count": 15, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "['BLUE' 'RED' 'BLUE' 'RED' 'BLUE' 'RED' 'RED' 'RED' 'RED' 'RED' 'RED'\n", " 'RED' 'RED' 'RED' 'RED' 'BLUE' 'RED' 'BLUE' 'BLUE' 'BLUE' 'BLUE' 'BLUE'\n", " 'RED' 'RED' 'BLUE' 'BLUE' 'BLUE' 'BLUE' 'RED' 'BLUE' 'BLUE' 'RED' 'BLUE'\n", " 'BLUE' 'BLUE' 'RED' 'RED' 'RED' 'RED' 'BLUE' 'RED' 'RED' 'RED' 'BLUE'\n", " 'BLUE' 'BLUE' 'RED' 'RED' 'BLUE' 'BLUE' 'BLUE' 'RED' 'RED']\n", "['BLUE' 'RED' 'BLUE' 'RED' 'BLUE' 'RED' 'BLUE' 'RED' 'BLUE' 'RED' 'BLUE'\n", " 'RED' 'RED' 'BLUE' 'RED' 'BLUE' 'RED' 'BLUE' 'RED' 'BLUE' 'BLUE' 'RED'\n", " 'BLUE' 'RED' 'BLUE' 'RED' 'RED' 'BLUE' 'RED' 'BLUE' 'RED' 'RED' 'BLUE'\n", " 'RED' 'RED' 'BLUE' 'BLUE' 'RED' 'RED' 'BLUE' 'RED' 'RED' 'BLUE' 'BLUE'\n", " 'BLUE' 'BLUE' 'BLUE' 'RED' 'BLUE' 'BLUE' 'BLUE' 'RED' 'RED']\n" ] } ] }, { "cell_type": "code", "source": [ "# accuration\n", "accuracy = accuracy_score(y_test, y_pred)\n", "print(f'Accuracy: {accuracy:.2f}')\n", "print()\n", "\n", "# Confusion matrix\n", "cm = confusion_matrix(y_test, y_pred,labels=model.classes_)\n", "print('Confusion Matrix:')\n", "print(cm)\n", "print()\n", "\n", "# Classification report\n", "report = classification_report(y_test, y_pred)\n", "print('Classification Report:')\n", "print(report)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "57LnREhz-I-X", "outputId": "064695ae-967c-4e0c-e81f-56c56708d360" }, "execution_count": 16, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Accuracy: 0.70\n", "\n", "Confusion Matrix:\n", "[[18 9]\n", " [ 7 19]]\n", "\n", "Classification Report:\n", " precision recall f1-score support\n", "\n", " BLUE 0.72 0.67 0.69 27\n", " RED 0.68 0.73 0.70 26\n", "\n", " accuracy 0.70 53\n", " macro avg 0.70 0.70 0.70 53\n", "weighted avg 0.70 0.70 0.70 53\n", "\n" ] } ] }, { "cell_type": "markdown", "source": [ "# Initialize a local repository" ], "metadata": { "id": "Xj4NjZshI8Z5" } }, { "cell_type": "code", "source": [ "# _, pkl_name = mkstemp(prefix=\"skops-\", suffix=\".pkl\")\n", "# local_repo = mkdtemp(prefix=\"skops-\")\n", "\n", "# with open(pkl_name, mode=\"bw\") as f:\n", "# sio.dump(model, file=f)\n", "\n", "pkl_name = \"model.pkl\"\n", "local_repo = \"mpl-id-s14-prediction\"\n", "\n", "with open(pkl_name, mode=\"bw\") as f:\n", " sio.dump(model, file=f)\n", "\n", "hub_utils.init(\n", " model=pkl_name,\n", " requirements=[f\"scikit-learn={sklearn.__version__}\"],\n", " dst=local_repo,\n", " task=\"tabular-classification\",\n", " data=X_test,\n", ")\n", "\n", "if \"__file__\" in locals(): # __file__ not defined during docs built\n", " # Add this script itself to the files to be uploaded for reproducibility\n", " hub_utils.add_files(__file__, dst=local_repo)" ], "metadata": { "id": "7ajeBC7hI3uj" }, "execution_count": 18, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Create model card" ], "metadata": { "id": "AJhkAc9NRzL0" } }, { "cell_type": "code", "source": [ "model_card = card.Card(model, metadata=card.metadata_from_config(Path(local_repo)))" ], "metadata": { "id": "HZFv0iwiR00h" }, "execution_count": 19, "outputs": [] }, { "cell_type": "code", "source": [ "limitations = (\"This model is only trained with MPL ID Season 14 data\")\n", "model_description = (\n", " \"This is a RandomForestClassifier model trained on mpl_id_s14 dataset.\"\n", ")\n", "model_card_authors = \"z4fL\"\n", "model_card.add(\n", " folded=False,\n", " **{\n", " \"Model Card Authors\": model_card_authors,\n", " \"Limitations\": limitations,\n", " \"Model description\": model_description,\n", " },\n", ")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "7Fx0zmHXR41H", "outputId": "de418a76-ed81-4259-8e3b-fa75eb30925b" }, "execution_count": 20, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "Card(\n", " model=RandomForestClassifier(max_depth=2..., n_estimators=50, random_state=91),\n", " metadata.library_name=sklearn,\n", " metadata.tags=['sklearn', 'skops', 'tabular-classification'],\n", " metadata.model_format=pickle,\n", " metadata.model_file=model.pkl,\n", " metadata.widget=[{...}],\n", " Model description=This is a RandomForestClassi...rained on mpl_id_s14 dataset.,\n", " Model description/Training Procedure/Hyperparameters=TableSection(19x2),\n", " Model description/Training Procedure/Model Plot=