{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "train_model.ipynb",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"id": "kFAHrl4RTtV4"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.ensemble import RandomForestRegressor\n"
]
},
{
"cell_type": "code",
"source": [
"wine = pd.read_csv(\"winequality-red.csv\")"
],
"metadata": {
"id": "PtRnEnZqUVz3"
},
"execution_count": 24,
"outputs": []
},
{
"cell_type": "code",
"source": [
"wine.describe()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 346
},
"id": "BEyAqxlzcY7K",
"outputId": "c9a6c736-27c4-4fcf-a70c-bb9f63dbb021"
},
"execution_count": 43,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
"
\n",
"
\n",
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" fixed acidity | \n",
" volatile acidity | \n",
" citric acid | \n",
" residual sugar | \n",
" chlorides | \n",
" free sulfur dioxide | \n",
" total sulfur dioxide | \n",
" density | \n",
" pH | \n",
" sulphates | \n",
" alcohol | \n",
" quality | \n",
"
\n",
" \n",
" \n",
" \n",
" count | \n",
" 1599.000000 | \n",
" 1599.000000 | \n",
" 1599.000000 | \n",
" 1599.000000 | \n",
" 1599.000000 | \n",
" 1599.000000 | \n",
" 1599.000000 | \n",
" 1599.000000 | \n",
" 1599.000000 | \n",
" 1599.000000 | \n",
" 1599.000000 | \n",
" 1599.000000 | \n",
"
\n",
" \n",
" mean | \n",
" 8.319637 | \n",
" 0.527821 | \n",
" 0.270976 | \n",
" 2.538806 | \n",
" 0.087467 | \n",
" 15.874922 | \n",
" 46.467792 | \n",
" 0.996747 | \n",
" 3.311113 | \n",
" 0.658149 | \n",
" 10.422983 | \n",
" 5.636023 | \n",
"
\n",
" \n",
" std | \n",
" 1.741096 | \n",
" 0.179060 | \n",
" 0.194801 | \n",
" 1.409928 | \n",
" 0.047065 | \n",
" 10.460157 | \n",
" 32.895324 | \n",
" 0.001887 | \n",
" 0.154386 | \n",
" 0.169507 | \n",
" 1.065668 | \n",
" 0.807569 | \n",
"
\n",
" \n",
" min | \n",
" 4.600000 | \n",
" 0.120000 | \n",
" 0.000000 | \n",
" 0.900000 | \n",
" 0.012000 | \n",
" 1.000000 | \n",
" 6.000000 | \n",
" 0.990070 | \n",
" 2.740000 | \n",
" 0.330000 | \n",
" 8.400000 | \n",
" 3.000000 | \n",
"
\n",
" \n",
" 25% | \n",
" 7.100000 | \n",
" 0.390000 | \n",
" 0.090000 | \n",
" 1.900000 | \n",
" 0.070000 | \n",
" 7.000000 | \n",
" 22.000000 | \n",
" 0.995600 | \n",
" 3.210000 | \n",
" 0.550000 | \n",
" 9.500000 | \n",
" 5.000000 | \n",
"
\n",
" \n",
" 50% | \n",
" 7.900000 | \n",
" 0.520000 | \n",
" 0.260000 | \n",
" 2.200000 | \n",
" 0.079000 | \n",
" 14.000000 | \n",
" 38.000000 | \n",
" 0.996750 | \n",
" 3.310000 | \n",
" 0.620000 | \n",
" 10.200000 | \n",
" 6.000000 | \n",
"
\n",
" \n",
" 75% | \n",
" 9.200000 | \n",
" 0.640000 | \n",
" 0.420000 | \n",
" 2.600000 | \n",
" 0.090000 | \n",
" 21.000000 | \n",
" 62.000000 | \n",
" 0.997835 | \n",
" 3.400000 | \n",
" 0.730000 | \n",
" 11.100000 | \n",
" 6.000000 | \n",
"
\n",
" \n",
" max | \n",
" 15.900000 | \n",
" 1.580000 | \n",
" 1.000000 | \n",
" 15.500000 | \n",
" 0.611000 | \n",
" 72.000000 | \n",
" 289.000000 | \n",
" 1.003690 | \n",
" 4.010000 | \n",
" 2.000000 | \n",
" 14.900000 | \n",
" 8.000000 | \n",
"
\n",
" \n",
"
\n",
"
\n",
"
\n",
" \n",
" \n",
"\n",
" \n",
"
\n",
"
\n",
" "
],
"text/plain": [
" fixed acidity volatile acidity ... alcohol quality\n",
"count 1599.000000 1599.000000 ... 1599.000000 1599.000000\n",
"mean 8.319637 0.527821 ... 10.422983 5.636023\n",
"std 1.741096 0.179060 ... 1.065668 0.807569\n",
"min 4.600000 0.120000 ... 8.400000 3.000000\n",
"25% 7.100000 0.390000 ... 9.500000 5.000000\n",
"50% 7.900000 0.520000 ... 10.200000 6.000000\n",
"75% 9.200000 0.640000 ... 11.100000 6.000000\n",
"max 15.900000 1.580000 ... 14.900000 8.000000\n",
"\n",
"[8 rows x 12 columns]"
]
},
"metadata": {},
"execution_count": 43
}
]
},
{
"cell_type": "code",
"source": [
"wine.head()"
],
"metadata": {
"id": "xuERkgD3UZlx",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 204
},
"outputId": "8376f32a-b09b-4c7a-be86-3413a770c3b5"
},
"execution_count": 25,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
" \n",
"
\n",
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" fixed acidity | \n",
" volatile acidity | \n",
" citric acid | \n",
" residual sugar | \n",
" chlorides | \n",
" free sulfur dioxide | \n",
" total sulfur dioxide | \n",
" density | \n",
" pH | \n",
" sulphates | \n",
" alcohol | \n",
" quality | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 7.4 | \n",
" 0.70 | \n",
" 0.00 | \n",
" 1.9 | \n",
" 0.076 | \n",
" 11.0 | \n",
" 34.0 | \n",
" 0.9978 | \n",
" 3.51 | \n",
" 0.56 | \n",
" 9.4 | \n",
" 5 | \n",
"
\n",
" \n",
" 1 | \n",
" 7.8 | \n",
" 0.88 | \n",
" 0.00 | \n",
" 2.6 | \n",
" 0.098 | \n",
" 25.0 | \n",
" 67.0 | \n",
" 0.9968 | \n",
" 3.20 | \n",
" 0.68 | \n",
" 9.8 | \n",
" 5 | \n",
"
\n",
" \n",
" 2 | \n",
" 7.8 | \n",
" 0.76 | \n",
" 0.04 | \n",
" 2.3 | \n",
" 0.092 | \n",
" 15.0 | \n",
" 54.0 | \n",
" 0.9970 | \n",
" 3.26 | \n",
" 0.65 | \n",
" 9.8 | \n",
" 5 | \n",
"
\n",
" \n",
" 3 | \n",
" 11.2 | \n",
" 0.28 | \n",
" 0.56 | \n",
" 1.9 | \n",
" 0.075 | \n",
" 17.0 | \n",
" 60.0 | \n",
" 0.9980 | \n",
" 3.16 | \n",
" 0.58 | \n",
" 9.8 | \n",
" 6 | \n",
"
\n",
" \n",
" 4 | \n",
" 7.4 | \n",
" 0.70 | \n",
" 0.00 | \n",
" 1.9 | \n",
" 0.076 | \n",
" 11.0 | \n",
" 34.0 | \n",
" 0.9978 | \n",
" 3.51 | \n",
" 0.56 | \n",
" 9.4 | \n",
" 5 | \n",
"
\n",
" \n",
"
\n",
"
\n",
"
\n",
" \n",
" \n",
"\n",
" \n",
"
\n",
"
\n",
" "
],
"text/plain": [
" fixed acidity volatile acidity citric acid ... sulphates alcohol quality\n",
"0 7.4 0.70 0.00 ... 0.56 9.4 5\n",
"1 7.8 0.88 0.00 ... 0.68 9.8 5\n",
"2 7.8 0.76 0.04 ... 0.65 9.8 5\n",
"3 11.2 0.28 0.56 ... 0.58 9.8 6\n",
"4 7.4 0.70 0.00 ... 0.56 9.4 5\n",
"\n",
"[5 rows x 12 columns]"
]
},
"metadata": {},
"execution_count": 25
}
]
},
{
"cell_type": "code",
"source": [
"# Target is the 'quality' column; every other column of `wine` is a feature.\n",
"y = wine['quality']\n",
"X = wine.drop(columns='quality')"
],
"metadata": {
"id": "G2XlFDL1UuGU"
},
"execution_count": 26,
"outputs": []
},
{
"cell_type": "code",
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)"
],
"metadata": {
"id": "fwQiTgseUZ8t"
},
"execution_count": 27,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Seed the forest (the train/test split already uses random_state=42) so the\n",
"# fitted model and its score are reproducible on Restart & Run All.\n",
"rfc = RandomForestRegressor(n_estimators=200, random_state=42)\n",
"rfc.fit(X_train, y_train)\n",
"# For a regressor, score() returns R^2 on the given data (here the held-out split).\n",
"rfc.score(X_test, y_test)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gfaWbU0XU8qx",
"outputId": "c1360af2-235a-441b-89dc-ed8a1a3cd652"
},
"execution_count": 28,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.452940101720804"
]
},
"metadata": {},
"execution_count": 28
}
]
},
{
"cell_type": "code",
"source": [
"preds = rfc.predict(X_test)"
],
"metadata": {
"id": "RS8QjOk6W3eW"
},
"execution_count": 31,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Show the first row of the loaded dataset (all feature columns plus quality).\n",
"# Uses `wine`: no `df` is ever defined in this notebook, so the old name would\n",
"# raise NameError on a fresh kernel.\n",
"wine.iloc[0, :]"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "p6A3L6kcYLg4",
"outputId": "27042b35-0070-4a8f-dc21-10e30ea4f03a"
},
"execution_count": 34,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"fixed acidity 7.4000\n",
"volatile acidity 0.7000\n",
"citric acid 0.0000\n",
"residual sugar 1.9000\n",
"chlorides 0.0760\n",
"free sulfur dioxide 11.0000\n",
"total sulfur dioxide 34.0000\n",
"density 0.9978\n",
"pH 3.5100\n",
"sulphates 0.5600\n",
"alcohol 9.4000\n",
"quality 5.0000\n",
"Name: 0, dtype: float64"
]
},
"metadata": {},
"execution_count": 34
}
]
},
{
"cell_type": "code",
"source": [
"df_pred = pd.DataFrame.from_dict({\n",
" 'fixed acidity': 7.4, \n",
" 'volatile acidity': 0.7, \n",
" 'citric acid': 0, \n",
" 'residual sugar': 1.9,\n",
" 'chlorides': 0.076, \n",
" 'free sulfur dioxide': 11, \n",
" 'total sulfur dioxide': 34, \n",
" 'density':0.9978,\n",
" 'pH': 3.51, \n",
" 'sulphates': 0.56, \n",
" 'alcohol':9.4\n",
"}, orient='index').T"
],
"metadata": {
"id": "YYRmAoyJYGKR"
},
"execution_count": 40,
"outputs": []
},
{
"cell_type": "code",
"source": [
"df_pred"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 80
},
"id": "fFCqmU6SY_S5",
"outputId": "4606cb07-1eb0-4b4e-8e7b-03eb505481fc"
},
"execution_count": 41,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
" \n",
"
\n",
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" fixed acidity | \n",
" volatile acidity | \n",
" citric acid | \n",
" residual sugar | \n",
" chlorides | \n",
" free sulfur dioxide | \n",
" total sulfur dioxide | \n",
" density | \n",
" pH | \n",
" sulphates | \n",
" alcohol | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 7.4 | \n",
" 0.7 | \n",
" 0.0 | \n",
" 1.9 | \n",
" 0.076 | \n",
" 11.0 | \n",
" 34.0 | \n",
" 0.9978 | \n",
" 3.51 | \n",
" 0.56 | \n",
" 9.4 | \n",
"
\n",
" \n",
"
\n",
"
\n",
"
\n",
" \n",
" \n",
"\n",
" \n",
"
\n",
"
\n",
" "
],
"text/plain": [
" fixed acidity volatile acidity citric acid ... pH sulphates alcohol\n",
"0 7.4 0.7 0.0 ... 3.51 0.56 9.4\n",
"\n",
"[1 rows x 11 columns]"
]
},
"metadata": {},
"execution_count": 41
}
]
},
{
"cell_type": "code",
"source": [
"rfc.predict(df_pred)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TbLBRotEYBOf",
"outputId": "6ac50cb8-147b-4c96-bc39-2261b736973e"
},
"execution_count": 42,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([5.025])"
]
},
"metadata": {},
"execution_count": 42
}
]
},
{
"cell_type": "code",
"source": [
"from joblib import dump, load\n",
"dump(rfc, 'wine_pred.joblib') "
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Wh5wXQqbWHNK",
"outputId": "84d59d4f-811b-4e1d-d182-267bbde56414"
},
"execution_count": 32,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['wine_pred.joblib']"
]
},
"metadata": {},
"execution_count": 32
}
]
},
{
"cell_type": "code",
"source": [
""
],
"metadata": {
"id": "BkfTMO4AXi4o"
},
"execution_count": null,
"outputs": []
}
]
}