{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "train_model.ipynb", "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "code", "execution_count": 23, "metadata": { "id": "kFAHrl4RTtV4" }, "outputs": [], "source": [ "import pandas as pd\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.ensemble import RandomForestRegressor\n" ] }, { "cell_type": "code", "source": [ "wine = pd.read_csv(\"winequality-red.csv\")" ], "metadata": { "id": "PtRnEnZqUVz3" }, "execution_count": 24, "outputs": [] }, { "cell_type": "code", "source": [ "wine.describe()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 346 }, "id": "BEyAqxlzcY7K", "outputId": "c9a6c736-27c4-4fcf-a70c-bb9f63dbb021" }, "execution_count": 43, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
fixed acidityvolatile aciditycitric acidresidual sugarchloridesfree sulfur dioxidetotal sulfur dioxidedensitypHsulphatesalcoholquality
count1599.0000001599.0000001599.0000001599.0000001599.0000001599.0000001599.0000001599.0000001599.0000001599.0000001599.0000001599.000000
mean8.3196370.5278210.2709762.5388060.08746715.87492246.4677920.9967473.3111130.65814910.4229835.636023
std1.7410960.1790600.1948011.4099280.04706510.46015732.8953240.0018870.1543860.1695071.0656680.807569
min4.6000000.1200000.0000000.9000000.0120001.0000006.0000000.9900702.7400000.3300008.4000003.000000
25%7.1000000.3900000.0900001.9000000.0700007.00000022.0000000.9956003.2100000.5500009.5000005.000000
50%7.9000000.5200000.2600002.2000000.07900014.00000038.0000000.9967503.3100000.62000010.2000006.000000
75%9.2000000.6400000.4200002.6000000.09000021.00000062.0000000.9978353.4000000.73000011.1000006.000000
max15.9000001.5800001.00000015.5000000.61100072.000000289.0000001.0036904.0100002.00000014.9000008.000000
\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ], "text/plain": [ " fixed acidity volatile acidity ... alcohol quality\n", "count 1599.000000 1599.000000 ... 1599.000000 1599.000000\n", "mean 8.319637 0.527821 ... 10.422983 5.636023\n", "std 1.741096 0.179060 ... 1.065668 0.807569\n", "min 4.600000 0.120000 ... 8.400000 3.000000\n", "25% 7.100000 0.390000 ... 9.500000 5.000000\n", "50% 7.900000 0.520000 ... 10.200000 6.000000\n", "75% 9.200000 0.640000 ... 11.100000 6.000000\n", "max 15.900000 1.580000 ... 14.900000 8.000000\n", "\n", "[8 rows x 12 columns]" ] }, "metadata": {}, "execution_count": 43 } ] }, { "cell_type": "code", "source": [ "wine.head()" ], "metadata": { "id": "xuERkgD3UZlx", "colab": { "base_uri": "https://localhost:8080/", "height": 204 }, "outputId": "8376f32a-b09b-4c7a-be86-3413a770c3b5" }, "execution_count": 25, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
fixed acidityvolatile aciditycitric acidresidual sugarchloridesfree sulfur dioxidetotal sulfur dioxidedensitypHsulphatesalcoholquality
07.40.700.001.90.07611.034.00.99783.510.569.45
17.80.880.002.60.09825.067.00.99683.200.689.85
27.80.760.042.30.09215.054.00.99703.260.659.85
311.20.280.561.90.07517.060.00.99803.160.589.86
47.40.700.001.90.07611.034.00.99783.510.569.45
\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ], "text/plain": [ " fixed acidity volatile acidity citric acid ... sulphates alcohol quality\n", "0 7.4 0.70 0.00 ... 0.56 9.4 5\n", "1 7.8 0.88 0.00 ... 0.68 9.8 5\n", "2 7.8 0.76 0.04 ... 0.65 9.8 5\n", "3 11.2 0.28 0.56 ... 0.58 9.8 6\n", "4 7.4 0.70 0.00 ... 0.56 9.4 5\n", "\n", "[5 rows x 12 columns]" ] }, "metadata": {}, "execution_count": 25 } ] }, { "cell_type": "code", "source": [ "# Features are every column except 'quality'; 'quality' is the regression target.\n", "X = wine.drop('quality', axis = 1)\n", "y = wine['quality']" ], "metadata": { "id": "G2XlFDL1UuGU" }, "execution_count": 26, "outputs": [] }, { "cell_type": "code", "source": [ "# Hold out 33% of the rows for evaluation; the split is seeded so it is repeatable.\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)" ], "metadata": { "id": "fwQiTgseUZ8t" }, "execution_count": 27, "outputs": [] }, { "cell_type": "code", "source": [ "# NOTE(review): no random_state is passed to the forest, so the fitted model and\n", "# the score shown below can differ from run to run.\n", "rfc = RandomForestRegressor(n_estimators=200)\n", "rfc.fit(X_train, y_train)\n", "# For a regressor, score() reports R^2 on the given data (here the held-out split).\n", "rfc.score(X_test, y_test)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "gfaWbU0XU8qx", "outputId": "c1360af2-235a-441b-89dc-ed8a1a3cd652" }, "execution_count": 28, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0.452940101720804" ] }, "metadata": {}, "execution_count": 28 } ] }, { "cell_type": "code", "source": [ "preds = rfc.predict(X_test)" ], "metadata": { "id": "RS8QjOk6W3eW" }, "execution_count": 31, "outputs": [] }, { "cell_type": "code", "source": [ "# First row of the loaded dataset; its feature values are reused for the\n", "# single-sample prediction check in the next cells.\n", "wine.iloc[0,:]" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "p6A3L6kcYLg4", "outputId": "27042b35-0070-4a8f-dc21-10e30ea4f03a" }, "execution_count": 34, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "fixed acidity 7.4000\n", "volatile acidity 0.7000\n", "citric acid 0.0000\n", "residual sugar 1.9000\n", "chlorides 0.0760\n", "free sulfur dioxide 11.0000\n", "total sulfur dioxide 34.0000\n", "density 0.9978\n", "pH 3.5100\n", "sulphates 0.5600\n", "alcohol 9.4000\n", "quality 5.0000\n", "Name: 0, dtype: float64" ] }, "metadata": {},
"execution_count": 34 } ] }, { "cell_type": "code", "source": [ "# One-row frame with the 11 feature values of dataset row 0 (no 'quality' column).\n", "# from_dict(..., orient='index') yields one row per key, so .T turns the keys into columns.\n", "df_pred = pd.DataFrame.from_dict({\n", "    'fixed acidity': 7.4,\n", "    'volatile acidity': 0.7,\n", "    'citric acid': 0,\n", "    'residual sugar': 1.9,\n", "    'chlorides': 0.076,\n", "    'free sulfur dioxide': 11,\n", "    'total sulfur dioxide': 34,\n", "    'density': 0.9978,\n", "    'pH': 3.51,\n", "    'sulphates': 0.56,\n", "    'alcohol': 9.4\n", "}, orient='index').T" ], "metadata": { "id": "YYRmAoyJYGKR" }, "execution_count": 40, "outputs": [] }, { "cell_type": "code", "source": [ "df_pred" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 80 }, "id": "fFCqmU6SY_S5", "outputId": "4606cb07-1eb0-4b4e-8e7b-03eb505481fc" }, "execution_count": 41, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
fixed acidityvolatile aciditycitric acidresidual sugarchloridesfree sulfur dioxidetotal sulfur dioxidedensitypHsulphatesalcohol
07.40.70.01.90.07611.034.00.99783.510.569.4
\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ], "text/plain": [ " fixed acidity volatile acidity citric acid ... pH sulphates alcohol\n", "0 7.4 0.7 0.0 ... 3.51 0.56 9.4\n", "\n", "[1 rows x 11 columns]" ] }, "metadata": {}, "execution_count": 41 } ] }, { "cell_type": "code", "source": [ "# Predict quality for the single-row sample frame built above.\n", "rfc.predict(df_pred)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "TbLBRotEYBOf", "outputId": "6ac50cb8-147b-4c96-bc39-2261b736973e" }, "execution_count": 42, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([5.025])" ] }, "metadata": {}, "execution_count": 42 } ] }, { "cell_type": "code", "source": [ "from joblib import dump, load\n", "# Persist the fitted regressor to 'wine_pred.joblib' in the working directory.\n", "dump(rfc, 'wine_pred.joblib')" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Wh5wXQqbWHNK", "outputId": "84d59d4f-811b-4e1d-d182-267bbde56414" }, "execution_count": 32, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "['wine_pred.joblib']" ] }, "metadata": {}, "execution_count": 32 } ] } ] }