samyak152002 commited on
Commit
fa68b3f
1 Parent(s): 1b171bc

Upload catboost regressor.ipynb

Browse files
Files changed (1) hide show
  1. catboost regressor.ipynb +1011 -0
catboost regressor.ipynb ADDED
@@ -0,0 +1,1011 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "3723782b",
7
+ "metadata": {
8
+ "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
9
+ "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
10
+ "execution": {
11
+ "iopub.execute_input": "2022-09-05T13:02:09.644656Z",
12
+ "iopub.status.busy": "2022-09-05T13:02:09.643933Z",
13
+ "iopub.status.idle": "2022-09-05T13:02:09.657508Z",
14
+ "shell.execute_reply": "2022-09-05T13:02:09.656438Z"
15
+ },
16
+ "papermill": {
17
+ "duration": 0.023341,
18
+ "end_time": "2022-09-05T13:02:09.660113",
19
+ "exception": false,
20
+ "start_time": "2022-09-05T13:02:09.636772",
21
+ "status": "completed"
22
+ },
23
+ "tags": []
24
+ },
25
+ "outputs": [
26
+ {
27
+ "name": "stdout",
28
+ "output_type": "stream",
29
+ "text": [
30
+ "/kaggle/input/nsutai/sample_submission.csv\n",
31
+ "/kaggle/input/nsutai/train.csv\n",
32
+ "/kaggle/input/nsutai/test.csv\n"
33
+ ]
34
+ }
35
+ ],
36
+ "source": [
37
+ "# This Python 3 environment comes with many helpful analytics libraries installed\n",
38
+ "# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n",
39
+ "# For example, here's several helpful packages to load\n",
40
+ "\n",
41
+ "import numpy as np # linear algebra\n",
42
+ "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
43
+ "\n",
44
+ "# Input data files are available in the read-only \"../input/\" directory\n",
45
+ "# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n",
46
+ "\n",
47
+ "import os\n",
48
+ "for dirname, _, filenames in os.walk('/kaggle/input'):\n",
49
+ " for filename in filenames:\n",
50
+ " print(os.path.join(dirname, filename))\n",
51
+ "\n",
52
+ "# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n",
53
+ "# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": 2,
59
+ "id": "909fe5b2",
60
+ "metadata": {
61
+ "execution": {
62
+ "iopub.execute_input": "2022-09-05T13:02:09.669974Z",
63
+ "iopub.status.busy": "2022-09-05T13:02:09.668381Z",
64
+ "iopub.status.idle": "2022-09-05T13:02:10.932568Z",
65
+ "shell.execute_reply": "2022-09-05T13:02:10.931603Z"
66
+ },
67
+ "papermill": {
68
+ "duration": 1.271004,
69
+ "end_time": "2022-09-05T13:02:10.934930",
70
+ "exception": false,
71
+ "start_time": "2022-09-05T13:02:09.663926",
72
+ "status": "completed"
73
+ },
74
+ "tags": []
75
+ },
76
+ "outputs": [],
77
+ "source": [
78
+ "from catboost import CatBoostRegressor\n",
79
+ "from sklearn.model_selection import ShuffleSplit, GridSearchCV\n",
80
+ "from sklearn.preprocessing import StandardScaler\n",
81
+ "from sklearn.model_selection import train_test_split, cross_val_score, cross_val_predict\n",
82
+ "import pandas as pd\n",
83
+ "import numpy as np\n",
84
+ "import seaborn as sns\n",
85
+ "import matplotlib.pyplot as plt"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": 3,
91
+ "id": "1462638a",
92
+ "metadata": {
93
+ "execution": {
94
+ "iopub.execute_input": "2022-09-05T13:02:10.943879Z",
95
+ "iopub.status.busy": "2022-09-05T13:02:10.943596Z",
96
+ "iopub.status.idle": "2022-09-05T13:02:11.394396Z",
97
+ "shell.execute_reply": "2022-09-05T13:02:11.393459Z"
98
+ },
99
+ "papermill": {
100
+ "duration": 0.457904,
101
+ "end_time": "2022-09-05T13:02:11.396743",
102
+ "exception": false,
103
+ "start_time": "2022-09-05T13:02:10.938839",
104
+ "status": "completed"
105
+ },
106
+ "tags": []
107
+ },
108
+ "outputs": [],
109
+ "source": [
110
+ "df = pd.read_csv('../input/nsutai/train.csv')\n",
111
+ "# df = df.drop('name',axis=1)\n",
112
+ "# df = df.drop('sentry_object',axis=1)\n",
113
+ "df = df.drop('event',axis=1)\n",
114
+ "df = df.drop('id',axis=1)\n",
115
+ "terms = 0\n",
116
+ "total = 0\n",
117
+ "for i in df:\n",
118
+ " df[i].fillna(value=df[i].mean(), inplace=True)\n",
119
+ "df['px12'] = df.px1 * df.px2\n",
120
+ "df['Q12'] = df.q1 * df.q2\n",
121
+ "df['phi12'] = df.phi1 * df.phi2\n",
122
+ "df['eta12'] = df.eta1 * df.eta2\n",
123
+ "df['pt12'] = df.pt1 * df.pt2\n",
124
+ "df['E12'] = df.e1 * df.e2\n",
125
+ "df['pz_diff'] = df.pz1 - df.pz2\n",
126
+ "df['eta_diff'] = df.eta1 - df.eta2"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": 4,
132
+ "id": "a9ed9929",
133
+ "metadata": {
134
+ "execution": {
135
+ "iopub.execute_input": "2022-09-05T13:02:11.405872Z",
136
+ "iopub.status.busy": "2022-09-05T13:02:11.405576Z",
137
+ "iopub.status.idle": "2022-09-05T13:02:11.426360Z",
138
+ "shell.execute_reply": "2022-09-05T13:02:11.425332Z"
139
+ },
140
+ "papermill": {
141
+ "duration": 0.02842,
142
+ "end_time": "2022-09-05T13:02:11.429254",
143
+ "exception": false,
144
+ "start_time": "2022-09-05T13:02:11.400834",
145
+ "status": "completed"
146
+ },
147
+ "tags": []
148
+ },
149
+ "outputs": [
150
+ {
151
+ "name": "stdout",
152
+ "output_type": "stream",
153
+ "text": [
154
+ "<class 'pandas.core.frame.DataFrame'>\n",
155
+ "RangeIndex: 69940 entries, 0 to 69939\n",
156
+ "Data columns (total 25 columns):\n",
157
+ " # Column Non-Null Count Dtype \n",
158
+ "--- ------ -------------- ----- \n",
159
+ " 0 e1 69940 non-null float64\n",
160
+ " 1 px1 69940 non-null float64\n",
161
+ " 2 py1 69940 non-null float64\n",
162
+ " 3 pz1 69940 non-null float64\n",
163
+ " 4 pt1 69940 non-null float64\n",
164
+ " 5 eta1 69940 non-null float64\n",
165
+ " 6 phi1 69940 non-null float64\n",
166
+ " 7 q1 69940 non-null int64 \n",
167
+ " 8 e2 69940 non-null float64\n",
168
+ " 9 px2 69940 non-null float64\n",
169
+ " 10 py2 69940 non-null float64\n",
170
+ " 11 pz2 69940 non-null float64\n",
171
+ " 12 pt2 69940 non-null float64\n",
172
+ " 13 eta2 69940 non-null float64\n",
173
+ " 14 phi2 69940 non-null float64\n",
174
+ " 15 q2 69940 non-null int64 \n",
175
+ " 16 mass 69940 non-null float64\n",
176
+ " 17 px12 69940 non-null float64\n",
177
+ " 18 Q12 69940 non-null int64 \n",
178
+ " 19 phi12 69940 non-null float64\n",
179
+ " 20 eta12 69940 non-null float64\n",
180
+ " 21 pt12 69940 non-null float64\n",
181
+ " 22 E12 69940 non-null float64\n",
182
+ " 23 pz_diff 69940 non-null float64\n",
183
+ " 24 eta_diff 69940 non-null float64\n",
184
+ "dtypes: float64(22), int64(3)\n",
185
+ "memory usage: 13.3 MB\n"
186
+ ]
187
+ }
188
+ ],
189
+ "source": [
190
+ "df.info()"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": 5,
196
+ "id": "ff7f9418",
197
+ "metadata": {
198
+ "execution": {
199
+ "iopub.execute_input": "2022-09-05T13:02:11.438260Z",
200
+ "iopub.status.busy": "2022-09-05T13:02:11.437997Z",
201
+ "iopub.status.idle": "2022-09-05T13:02:11.483587Z",
202
+ "shell.execute_reply": "2022-09-05T13:02:11.482534Z"
203
+ },
204
+ "papermill": {
205
+ "duration": 0.05382,
206
+ "end_time": "2022-09-05T13:02:11.487122",
207
+ "exception": false,
208
+ "start_time": "2022-09-05T13:02:11.433302",
209
+ "status": "completed"
210
+ },
211
+ "tags": []
212
+ },
213
+ "outputs": [],
214
+ "source": [
215
+ "Xn_data = np.array(df.drop('mass', axis = 1).values)\n",
216
+ "# Normalize X_data\n",
217
+ "X_data = (Xn_data - np.mean(Xn_data)) / np.std(Xn_data)\n",
218
+ "# y_data = np.log(np.array(df[\"mass\"]))\n",
219
+ "from sklearn.preprocessing import StandardScaler"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": 6,
225
+ "id": "d33ee3a8",
226
+ "metadata": {
227
+ "execution": {
228
+ "iopub.execute_input": "2022-09-05T13:02:11.501842Z",
229
+ "iopub.status.busy": "2022-09-05T13:02:11.501512Z",
230
+ "iopub.status.idle": "2022-09-05T13:02:11.576250Z",
231
+ "shell.execute_reply": "2022-09-05T13:02:11.574558Z"
232
+ },
233
+ "papermill": {
234
+ "duration": 0.084845,
235
+ "end_time": "2022-09-05T13:02:11.580949",
236
+ "exception": false,
237
+ "start_time": "2022-09-05T13:02:11.496104",
238
+ "status": "completed"
239
+ },
240
+ "tags": []
241
+ },
242
+ "outputs": [],
243
+ "source": [
244
+ "X = df.drop('mass',axis=1)\n",
245
+ "# from sklearn.preprocessing import StandardScaler\n",
246
+ "# sc = StandardScaler()\n",
247
+ "# X = sc.fit_transform(X)\n",
248
+ "y = df['mass'].values\n",
249
+ "\n",
250
+ "X_train, X_test, y_train, y_test = train_test_split(X_data,y,test_size=0.2,random_state=33)\n",
251
+ "ss = StandardScaler()\n",
252
+ "X_train = ss.fit_transform(X_train)\n",
253
+ "X_test = ss.fit_transform(X_test)"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": 7,
259
+ "id": "0fea2318",
260
+ "metadata": {
261
+ "execution": {
262
+ "iopub.execute_input": "2022-09-05T13:02:11.598458Z",
263
+ "iopub.status.busy": "2022-09-05T13:02:11.597860Z",
264
+ "iopub.status.idle": "2022-09-05T13:02:11.612775Z",
265
+ "shell.execute_reply": "2022-09-05T13:02:11.611552Z"
266
+ },
267
+ "papermill": {
268
+ "duration": 0.027089,
269
+ "end_time": "2022-09-05T13:02:11.616455",
270
+ "exception": false,
271
+ "start_time": "2022-09-05T13:02:11.589366",
272
+ "status": "completed"
273
+ },
274
+ "tags": []
275
+ },
276
+ "outputs": [
277
+ {
278
+ "data": {
279
+ "text/plain": [
280
+ "array([[ 1.54406036e-03, 5.00046796e-02, 9.72291485e-01, ...,\n",
281
+ " 3.46918585e-03, -4.15573889e-02, -7.97763559e-03],\n",
282
+ " [ 1.94422910e-04, -3.33701351e-01, -1.21056344e+00, ...,\n",
283
+ " 3.65737676e-03, -1.13711783e-01, 2.60815651e-03],\n",
284
+ " [ 2.38266960e-03, 5.19004764e-01, 8.78117565e-01, ...,\n",
285
+ " 2.62745940e-03, 8.97889389e-02, -2.15806071e-02],\n",
286
+ " ...,\n",
287
+ " [-1.40730483e-04, -1.79267428e-01, 1.03092574e+00, ...,\n",
288
+ " 3.74511846e-03, -2.07740905e+00, 7.78566082e-03],\n",
289
+ " [-8.58332036e-04, -7.98378259e-01, -1.12113137e+00, ...,\n",
290
+ " 2.87497174e-03, 4.31631201e-01, 8.14307267e-04],\n",
291
+ " [ 1.02554766e-03, -1.28674581e+00, -1.34946525e-01, ...,\n",
292
+ " 3.44007559e-03, 2.95152031e-03, -5.08424031e-04]])"
293
+ ]
294
+ },
295
+ "execution_count": 7,
296
+ "metadata": {},
297
+ "output_type": "execute_result"
298
+ }
299
+ ],
300
+ "source": [
301
+ "X_train"
302
+ ]
303
+ },
304
+ {
305
+ "cell_type": "code",
306
+ "execution_count": 8,
307
+ "id": "79f26373",
308
+ "metadata": {
309
+ "execution": {
310
+ "iopub.execute_input": "2022-09-05T13:02:11.633832Z",
311
+ "iopub.status.busy": "2022-09-05T13:02:11.633551Z",
312
+ "iopub.status.idle": "2022-09-05T13:02:11.638935Z",
313
+ "shell.execute_reply": "2022-09-05T13:02:11.638026Z"
314
+ },
315
+ "papermill": {
316
+ "duration": 0.01656,
317
+ "end_time": "2022-09-05T13:02:11.641678",
318
+ "exception": false,
319
+ "start_time": "2022-09-05T13:02:11.625118",
320
+ "status": "completed"
321
+ },
322
+ "tags": []
323
+ },
324
+ "outputs": [
325
+ {
326
+ "name": "stdout",
327
+ "output_type": "stream",
328
+ "text": [
329
+ "Location of categorical columns : ['q1', 'q2', 'Q12']\n"
330
+ ]
331
+ }
332
+ ],
333
+ "source": [
334
+ "#List of categorical columns\n",
335
+ "# categoricalcolumns = X.select_dtypes(include=[\"object\"]).columns.tolist()\n",
336
+ "# print(\"Names of categorical columns : \", categoricalcolumns)\n",
337
+ "#Get location of categorical columns\n",
338
+ "cat_features = ['q1','q2','Q12']\n",
339
+ "print(\"Location of categorical columns : \",cat_features)\n"
340
+ ]
341
+ },
342
+ {
343
+ "cell_type": "code",
344
+ "execution_count": 9,
345
+ "id": "255a87ee",
346
+ "metadata": {
347
+ "execution": {
348
+ "iopub.execute_input": "2022-09-05T13:02:11.651657Z",
349
+ "iopub.status.busy": "2022-09-05T13:02:11.651374Z",
350
+ "iopub.status.idle": "2022-09-05T13:02:12.403415Z",
351
+ "shell.execute_reply": "2022-09-05T13:02:12.402465Z"
352
+ },
353
+ "papermill": {
354
+ "duration": 0.760088,
355
+ "end_time": "2022-09-05T13:02:12.406197",
356
+ "exception": false,
357
+ "start_time": "2022-09-05T13:02:11.646109",
358
+ "status": "completed"
359
+ },
360
+ "tags": []
361
+ },
362
+ "outputs": [],
363
+ "source": [
364
+ "# importing Pool\n",
365
+ "from catboost import Pool\n",
366
+ "#Creating pool object for train dataset. we give information of categorical fetures to parameter cat_fetaures\n",
367
+ "train_data = Pool(data=X_train,\n",
368
+ " label=y_train,\n",
369
+ "# cat_features=cat_features\n",
370
+ " )\n",
371
+ "#Creating pool object for test dataset\n",
372
+ "test_data = Pool(data=X_test,\n",
373
+ " label=y_test,\n",
374
+ "# cat_features=cat_features\n",
375
+ " )"
376
+ ]
377
+ },
378
+ {
379
+ "cell_type": "code",
380
+ "execution_count": 10,
381
+ "id": "513ae7cc",
382
+ "metadata": {
383
+ "execution": {
384
+ "iopub.execute_input": "2022-09-05T13:02:12.420374Z",
385
+ "iopub.status.busy": "2022-09-05T13:02:12.420051Z",
386
+ "iopub.status.idle": "2022-09-05T13:02:12.423954Z",
387
+ "shell.execute_reply": "2022-09-05T13:02:12.423122Z"
388
+ },
389
+ "papermill": {
390
+ "duration": 0.015038,
391
+ "end_time": "2022-09-05T13:02:12.427807",
392
+ "exception": false,
393
+ "start_time": "2022-09-05T13:02:12.412769",
394
+ "status": "completed"
395
+ },
396
+ "tags": []
397
+ },
398
+ "outputs": [],
399
+ "source": [
400
+ "# print(np.sqrt(error))"
401
+ ]
402
+ },
403
+ {
404
+ "cell_type": "code",
405
+ "execution_count": 11,
406
+ "id": "f6829920",
407
+ "metadata": {
408
+ "execution": {
409
+ "iopub.execute_input": "2022-09-05T13:02:12.441091Z",
410
+ "iopub.status.busy": "2022-09-05T13:02:12.440796Z",
411
+ "iopub.status.idle": "2022-09-05T14:45:06.059945Z",
412
+ "shell.execute_reply": "2022-09-05T14:45:06.058899Z"
413
+ },
414
+ "papermill": {
415
+ "duration": 6173.62882,
416
+ "end_time": "2022-09-05T14:45:06.062699",
417
+ "exception": false,
418
+ "start_time": "2022-09-05T13:02:12.433879",
419
+ "status": "completed"
420
+ },
421
+ "tags": []
422
+ },
423
+ "outputs": [
424
+ {
425
+ "name": "stdout",
426
+ "output_type": "stream",
427
+ "text": [
428
+ "Learning rate set to 0.011444\n",
429
+ "0:\tlearn: 1.0109453\ttotal: 25.3ms\tremaining: 8m 25s\n",
430
+ "5000:\tlearn: 0.2627569\ttotal: 1m 41s\tremaining: 5m 3s\n",
431
+ "10000:\tlearn: 0.0976646\ttotal: 3m 22s\tremaining: 3m 22s\n",
432
+ "15000:\tlearn: 0.0396183\ttotal: 5m 4s\tremaining: 1m 41s\n",
433
+ "19999:\tlearn: 0.0169753\ttotal: 6m 44s\tremaining: 0us\n",
434
+ "Learning rate set to 0.008715\n",
435
+ "0:\tlearn: 1.0111765\ttotal: 32ms\tremaining: 10m 39s\n",
436
+ "5000:\tlearn: 0.5784365\ttotal: 12m 45s\tremaining: 38m 15s\n",
437
+ "10000:\tlearn: 0.5133532\ttotal: 25m 6s\tremaining: 25m 5s\n",
438
+ "15000:\tlearn: 0.4772074\ttotal: 37m 26s\tremaining: 12m 28s\n",
439
+ "19999:\tlearn: 0.4527199\ttotal: 49m 6s\tremaining: 0us\n",
440
+ "Learning rate set to 0.008715\n",
441
+ "0:\tlearn: 1.0070361\ttotal: 76.6ms\tremaining: 25m 31s\n",
442
+ "5000:\tlearn: 0.5676638\ttotal: 12m 51s\tremaining: 38m 33s\n",
443
+ "10000:\tlearn: 0.5046212\ttotal: 23m 32s\tremaining: 23m 32s\n",
444
+ "15000:\tlearn: 0.4696909\ttotal: 34m 21s\tremaining: 11m 26s\n",
445
+ "19999:\tlearn: 0.4463740\ttotal: 45m 11s\tremaining: 0us\n",
446
+ "0.8447744377236096\n"
447
+ ]
448
+ }
449
+ ],
450
+ "source": [
451
+ "catbr = CatBoostRegressor(verbose=5000,task_type='GPU',loss_function='RMSE',iterations = 20000,depth = 12).fit(X_train,y_train)\n",
452
+ "# R2CV = cross_val_score(catbr,X_test,y_test,cv=3,scoring=\"r2\").mean()\n",
453
+ "error = -cross_val_score(catbr,X_test,y_test,cv=2,scoring=\"neg_mean_squared_error\").mean()\n",
454
+ "# print(R2CV)\n",
455
+ "print(np.sqrt(error))"
456
+ ]
457
+ },
458
+ {
459
+ "cell_type": "code",
460
+ "execution_count": 12,
461
+ "id": "96a52e15",
462
+ "metadata": {
463
+ "execution": {
464
+ "iopub.execute_input": "2022-09-05T14:45:06.074785Z",
465
+ "iopub.status.busy": "2022-09-05T14:45:06.074091Z",
466
+ "iopub.status.idle": "2022-09-05T14:45:06.077959Z",
467
+ "shell.execute_reply": "2022-09-05T14:45:06.077027Z"
468
+ },
469
+ "papermill": {
470
+ "duration": 0.011987,
471
+ "end_time": "2022-09-05T14:45:06.080009",
472
+ "exception": false,
473
+ "start_time": "2022-09-05T14:45:06.068022",
474
+ "status": "completed"
475
+ },
476
+ "tags": []
477
+ },
478
+ "outputs": [],
479
+ "source": [
480
+ "# params = {\n",
481
+ " \n",
482
+ "# \"depth\": [2, 3, 4, 5, 6],\n",
483
+ "# \"learning_rate\": [0.1, 0.01, 0.5]\n",
484
+ "# }"
485
+ ]
486
+ },
487
+ {
488
+ "cell_type": "code",
489
+ "execution_count": 13,
490
+ "id": "58836338",
491
+ "metadata": {
492
+ "execution": {
493
+ "iopub.execute_input": "2022-09-05T14:45:06.091493Z",
494
+ "iopub.status.busy": "2022-09-05T14:45:06.091219Z",
495
+ "iopub.status.idle": "2022-09-05T14:45:06.095298Z",
496
+ "shell.execute_reply": "2022-09-05T14:45:06.094449Z"
497
+ },
498
+ "papermill": {
499
+ "duration": 0.012206,
500
+ "end_time": "2022-09-05T14:45:06.097198",
501
+ "exception": false,
502
+ "start_time": "2022-09-05T14:45:06.084992",
503
+ "status": "completed"
504
+ },
505
+ "tags": []
506
+ },
507
+ "outputs": [],
508
+ "source": [
509
+ "# cv = GridSearchCV(catbr, params, cv=10, verbose=True).fit(X_train, y_train)\n",
510
+ "# print(cv.best_params_)"
511
+ ]
512
+ },
513
+ {
514
+ "cell_type": "code",
515
+ "execution_count": 14,
516
+ "id": "81fb1672",
517
+ "metadata": {
518
+ "execution": {
519
+ "iopub.execute_input": "2022-09-05T14:45:06.107939Z",
520
+ "iopub.status.busy": "2022-09-05T14:45:06.107678Z",
521
+ "iopub.status.idle": "2022-09-05T14:45:06.111581Z",
522
+ "shell.execute_reply": "2022-09-05T14:45:06.110646Z"
523
+ },
524
+ "papermill": {
525
+ "duration": 0.011411,
526
+ "end_time": "2022-09-05T14:45:06.113552",
527
+ "exception": false,
528
+ "start_time": "2022-09-05T14:45:06.102141",
529
+ "status": "completed"
530
+ },
531
+ "tags": []
532
+ },
533
+ "outputs": [],
534
+ "source": [
535
+ "# catbrtuned = CatBoostRegressor(depth=5,learning_rate=0.01,verbose=False,task_type='GPU').fit(X_train,y_train)\n",
536
+ "\n",
537
+ "# # R2CVtuned = cross_val_score(catbrtuned,X_test,y_test,cv=15,scoring=\"r2\").mean()\n",
538
+ "# # print(R2CVtuned)\n",
539
+ "# errortuned = -cross_val_score(catbrtuned,X_test,y_test,cv=20,scoring=\"neg_mean_squared_error\").mean()\n",
540
+ "# print(np.sqrt(errortuned))"
541
+ ]
542
+ },
543
+ {
544
+ "cell_type": "code",
545
+ "execution_count": 15,
546
+ "id": "9654095f",
547
+ "metadata": {
548
+ "execution": {
549
+ "iopub.execute_input": "2022-09-05T14:45:06.124164Z",
550
+ "iopub.status.busy": "2022-09-05T14:45:06.123909Z",
551
+ "iopub.status.idle": "2022-09-05T14:45:06.130345Z",
552
+ "shell.execute_reply": "2022-09-05T14:45:06.129490Z"
553
+ },
554
+ "papermill": {
555
+ "duration": 0.013934,
556
+ "end_time": "2022-09-05T14:45:06.132270",
557
+ "exception": false,
558
+ "start_time": "2022-09-05T14:45:06.118336",
559
+ "status": "completed"
560
+ },
561
+ "tags": []
562
+ },
563
+ "outputs": [
564
+ {
565
+ "data": {
566
+ "text/plain": [
567
+ "array([[ 1.01680714e-02, 6.52076470e-01, -1.71361082e+00, ...,\n",
568
+ " 1.49118772e-03, 1.63726624e-01, 3.10255375e-03],\n",
569
+ " [ 1.24653160e-02, 1.40030095e-01, 9.08963077e-01, ...,\n",
570
+ " 6.54264303e-04, 1.87932998e+00, -2.37996666e-02],\n",
571
+ " [ 1.56609992e+00, 3.08161424e-01, -3.38753181e-01, ...,\n",
572
+ " -4.62755342e-02, 1.25592922e+00, -1.43120765e-02],\n",
573
+ " ...,\n",
574
+ " [ 8.24693308e-03, -1.65144148e-01, -1.91790940e+00, ...,\n",
575
+ " 1.21247452e-04, 1.06549128e+00, 3.83478623e-03],\n",
576
+ " [ 9.59390026e-05, -1.42118970e+00, -6.56800351e-01, ...,\n",
577
+ " 5.07400206e-04, -1.61068698e+00, 1.05205829e-03],\n",
578
+ " [ 2.32941353e-03, 1.13598956e-01, -1.35239977e-01, ...,\n",
579
+ " 8.06554653e-04, -4.00680340e-01, -5.26439826e-03]])"
580
+ ]
581
+ },
582
+ "execution_count": 15,
583
+ "metadata": {},
584
+ "output_type": "execute_result"
585
+ }
586
+ ],
587
+ "source": [
588
+ "X_test"
589
+ ]
590
+ },
591
+ {
592
+ "cell_type": "code",
593
+ "execution_count": 16,
594
+ "id": "0bb3dd56",
595
+ "metadata": {
596
+ "execution": {
597
+ "iopub.execute_input": "2022-09-05T14:45:06.143546Z",
598
+ "iopub.status.busy": "2022-09-05T14:45:06.143260Z",
599
+ "iopub.status.idle": "2022-09-05T14:45:06.354494Z",
600
+ "shell.execute_reply": "2022-09-05T14:45:06.353489Z"
601
+ },
602
+ "papermill": {
603
+ "duration": 0.219067,
604
+ "end_time": "2022-09-05T14:45:06.356593",
605
+ "exception": false,
606
+ "start_time": "2022-09-05T14:45:06.137526",
607
+ "status": "completed"
608
+ },
609
+ "tags": []
610
+ },
611
+ "outputs": [
612
+ {
613
+ "data": {
614
+ "text/html": [
615
+ "<div>\n",
616
+ "<style scoped>\n",
617
+ " .dataframe tbody tr th:only-of-type {\n",
618
+ " vertical-align: middle;\n",
619
+ " }\n",
620
+ "\n",
621
+ " .dataframe tbody tr th {\n",
622
+ " vertical-align: top;\n",
623
+ " }\n",
624
+ "\n",
625
+ " .dataframe thead th {\n",
626
+ " text-align: right;\n",
627
+ " }\n",
628
+ "</style>\n",
629
+ "<table border=\"1\" class=\"dataframe\">\n",
630
+ " <thead>\n",
631
+ " <tr style=\"text-align: right;\">\n",
632
+ " <th></th>\n",
633
+ " <th>e1</th>\n",
634
+ " <th>px1</th>\n",
635
+ " <th>py1</th>\n",
636
+ " <th>pz1</th>\n",
637
+ " <th>pt1</th>\n",
638
+ " <th>eta1</th>\n",
639
+ " <th>phi1</th>\n",
640
+ " <th>q1</th>\n",
641
+ " <th>e2</th>\n",
642
+ " <th>px2</th>\n",
643
+ " <th>py2</th>\n",
644
+ " <th>pz2</th>\n",
645
+ " <th>pt2</th>\n",
646
+ " <th>eta2</th>\n",
647
+ " <th>phi2</th>\n",
648
+ " <th>q2</th>\n",
649
+ " </tr>\n",
650
+ " </thead>\n",
651
+ " <tbody>\n",
652
+ " <tr>\n",
653
+ " <th>0</th>\n",
654
+ " <td>3.900910</td>\n",
655
+ " <td>24.920700</td>\n",
656
+ " <td>-1.102674</td>\n",
657
+ " <td>0.702628</td>\n",
658
+ " <td>24.98810</td>\n",
659
+ " <td>-42.967949</td>\n",
660
+ " <td>-0.073422</td>\n",
661
+ " <td>-1</td>\n",
662
+ " <td>0.726854</td>\n",
663
+ " <td>0.715108</td>\n",
664
+ " <td>-0.374241</td>\n",
665
+ " <td>-1.351969</td>\n",
666
+ " <td>-0.654767</td>\n",
667
+ " <td>0.674536</td>\n",
668
+ " <td>-0.987639</td>\n",
669
+ " <td>-1</td>\n",
670
+ " </tr>\n",
671
+ " <tr>\n",
672
+ " <th>1</th>\n",
673
+ " <td>-3.401169</td>\n",
674
+ " <td>-7.007870</td>\n",
675
+ " <td>0.581511</td>\n",
676
+ " <td>0.321741</td>\n",
677
+ " <td>8.03705</td>\n",
678
+ " <td>-0.162631</td>\n",
679
+ " <td>2.629960</td>\n",
680
+ " <td>1</td>\n",
681
+ " <td>36.603944</td>\n",
682
+ " <td>0.698290</td>\n",
683
+ " <td>0.936158</td>\n",
684
+ " <td>-0.772930</td>\n",
685
+ " <td>0.828964</td>\n",
686
+ " <td>0.647576</td>\n",
687
+ " <td>-0.532239</td>\n",
688
+ " <td>-1</td>\n",
689
+ " </tr>\n",
690
+ " <tr>\n",
691
+ " <th>2</th>\n",
692
+ " <td>0.804254</td>\n",
693
+ " <td>-1.219560</td>\n",
694
+ " <td>1.557146</td>\n",
695
+ " <td>0.949164</td>\n",
696
+ " <td>12.61310</td>\n",
697
+ " <td>0.337772</td>\n",
698
+ " <td>-1.667640</td>\n",
699
+ " <td>1</td>\n",
700
+ " <td>-25.401295</td>\n",
701
+ " <td>0.615361</td>\n",
702
+ " <td>0.550970</td>\n",
703
+ " <td>0.397435</td>\n",
704
+ " <td>-0.525276</td>\n",
705
+ " <td>0.723704</td>\n",
706
+ " <td>0.481615</td>\n",
707
+ " <td>1</td>\n",
708
+ " </tr>\n",
709
+ " <tr>\n",
710
+ " <th>3</th>\n",
711
+ " <td>1.073530</td>\n",
712
+ " <td>-2.721070</td>\n",
713
+ " <td>1.258642</td>\n",
714
+ " <td>0.268368</td>\n",
715
+ " <td>2.75403</td>\n",
716
+ " <td>3.670449</td>\n",
717
+ " <td>-2.986720</td>\n",
718
+ " <td>-1</td>\n",
719
+ " <td>-42.101333</td>\n",
720
+ " <td>0.741327</td>\n",
721
+ " <td>0.236630</td>\n",
722
+ " <td>-0.151015</td>\n",
723
+ " <td>0.472083</td>\n",
724
+ " <td>0.974790</td>\n",
725
+ " <td>0.356768</td>\n",
726
+ " <td>-1</td>\n",
727
+ " </tr>\n",
728
+ " <tr>\n",
729
+ " <th>4</th>\n",
730
+ " <td>0.704689</td>\n",
731
+ " <td>16.055100</td>\n",
732
+ " <td>1.471738</td>\n",
733
+ " <td>-0.863552</td>\n",
734
+ " <td>17.32390</td>\n",
735
+ " <td>2.430194</td>\n",
736
+ " <td>0.385110</td>\n",
737
+ " <td>-1</td>\n",
738
+ " <td>-1.048602</td>\n",
739
+ " <td>0.958089</td>\n",
740
+ " <td>0.268982</td>\n",
741
+ " <td>-0.264064</td>\n",
742
+ " <td>-0.969478</td>\n",
743
+ " <td>0.554407</td>\n",
744
+ " <td>-0.101867</td>\n",
745
+ " <td>-1</td>\n",
746
+ " </tr>\n",
747
+ " <tr>\n",
748
+ " <th>...</th>\n",
749
+ " <td>...</td>\n",
750
+ " <td>...</td>\n",
751
+ " <td>...</td>\n",
752
+ " <td>...</td>\n",
753
+ " <td>...</td>\n",
754
+ " <td>...</td>\n",
755
+ " <td>...</td>\n",
756
+ " <td>...</td>\n",
757
+ " <td>...</td>\n",
758
+ " <td>...</td>\n",
759
+ " <td>...</td>\n",
760
+ " <td>...</td>\n",
761
+ " <td>...</td>\n",
762
+ " <td>...</td>\n",
763
+ " <td>...</td>\n",
764
+ " <td>...</td>\n",
765
+ " </tr>\n",
766
+ " <tr>\n",
767
+ " <th>29970</th>\n",
768
+ " <td>-0.258829</td>\n",
769
+ " <td>-10.658000</td>\n",
770
+ " <td>-0.507161</td>\n",
771
+ " <td>0.228920</td>\n",
772
+ " <td>11.84330</td>\n",
773
+ " <td>0.296608</td>\n",
774
+ " <td>2.690370</td>\n",
775
+ " <td>-1</td>\n",
776
+ " <td>-4.251468</td>\n",
777
+ " <td>0.813939</td>\n",
778
+ " <td>73.667501</td>\n",
779
+ " <td>1.158823</td>\n",
780
+ " <td>-0.148324</td>\n",
781
+ " <td>0.954442</td>\n",
782
+ " <td>0.388798</td>\n",
783
+ " <td>1</td>\n",
784
+ " </tr>\n",
785
+ " <tr>\n",
786
+ " <th>29971</th>\n",
787
+ " <td>-0.405636</td>\n",
788
+ " <td>7.854990</td>\n",
789
+ " <td>1.024085</td>\n",
790
+ " <td>-0.860718</td>\n",
791
+ " <td>17.05020</td>\n",
792
+ " <td>-0.456347</td>\n",
793
+ " <td>1.092010</td>\n",
794
+ " <td>1</td>\n",
795
+ " <td>-1.616737</td>\n",
796
+ " <td>0.550067</td>\n",
797
+ " <td>1.162549</td>\n",
798
+ " <td>-0.311051</td>\n",
799
+ " <td>-0.897992</td>\n",
800
+ " <td>0.772986</td>\n",
801
+ " <td>-0.527932</td>\n",
802
+ " <td>1</td>\n",
803
+ " </tr>\n",
804
+ " <tr>\n",
805
+ " <th>29972</th>\n",
806
+ " <td>1.719597</td>\n",
807
+ " <td>-3.273500</td>\n",
808
+ " <td>1.397346</td>\n",
809
+ " <td>0.577056</td>\n",
810
+ " <td>3.28801</td>\n",
811
+ " <td>-1.215898</td>\n",
812
+ " <td>-3.047630</td>\n",
813
+ " <td>1</td>\n",
814
+ " <td>0.628021</td>\n",
815
+ " <td>0.542946</td>\n",
816
+ " <td>0.143970</td>\n",
817
+ " <td>-0.731390</td>\n",
818
+ " <td>-0.307851</td>\n",
819
+ " <td>0.734795</td>\n",
820
+ " <td>-0.766423</td>\n",
821
+ " <td>-1</td>\n",
822
+ " </tr>\n",
823
+ " <tr>\n",
824
+ " <th>29973</th>\n",
825
+ " <td>1.764202</td>\n",
826
+ " <td>11.352600</td>\n",
827
+ " <td>0.815074</td>\n",
828
+ " <td>0.930537</td>\n",
829
+ " <td>16.43280</td>\n",
830
+ " <td>-0.126448</td>\n",
831
+ " <td>0.808132</td>\n",
832
+ " <td>-1</td>\n",
833
+ " <td>-1.111305</td>\n",
834
+ " <td>0.932663</td>\n",
835
+ " <td>-0.063397</td>\n",
836
+ " <td>-0.153664</td>\n",
837
+ " <td>-0.999193</td>\n",
838
+ " <td>0.545050</td>\n",
839
+ " <td>0.121883</td>\n",
840
+ " <td>1</td>\n",
841
+ " </tr>\n",
842
+ " <tr>\n",
843
+ " <th>29974</th>\n",
844
+ " <td>4.601752</td>\n",
845
+ " <td>0.886162</td>\n",
846
+ " <td>0.556092</td>\n",
847
+ " <td>0.994464</td>\n",
848
+ " <td>5.55010</td>\n",
849
+ " <td>-2.074547</td>\n",
850
+ " <td>1.410440</td>\n",
851
+ " <td>1</td>\n",
852
+ " <td>-3.578146</td>\n",
853
+ " <td>0.628637</td>\n",
854
+ " <td>-9.284831</td>\n",
855
+ " <td>0.425168</td>\n",
856
+ " <td>0.147698</td>\n",
857
+ " <td>0.555115</td>\n",
858
+ " <td>0.964454</td>\n",
859
+ " <td>-1</td>\n",
860
+ " </tr>\n",
861
+ " </tbody>\n",
862
+ "</table>\n",
863
+ "<p>29975 rows × 16 columns</p>\n",
864
+ "</div>"
865
+ ],
866
+ "text/plain": [
867
+ " e1 px1 py1 pz1 pt1 eta1 phi1 \\\n",
868
+ "0 3.900910 24.920700 -1.102674 0.702628 24.98810 -42.967949 -0.073422 \n",
869
+ "1 -3.401169 -7.007870 0.581511 0.321741 8.03705 -0.162631 2.629960 \n",
870
+ "2 0.804254 -1.219560 1.557146 0.949164 12.61310 0.337772 -1.667640 \n",
871
+ "3 1.073530 -2.721070 1.258642 0.268368 2.75403 3.670449 -2.986720 \n",
872
+ "4 0.704689 16.055100 1.471738 -0.863552 17.32390 2.430194 0.385110 \n",
873
+ "... ... ... ... ... ... ... ... \n",
874
+ "29970 -0.258829 -10.658000 -0.507161 0.228920 11.84330 0.296608 2.690370 \n",
875
+ "29971 -0.405636 7.854990 1.024085 -0.860718 17.05020 -0.456347 1.092010 \n",
876
+ "29972 1.719597 -3.273500 1.397346 0.577056 3.28801 -1.215898 -3.047630 \n",
877
+ "29973 1.764202 11.352600 0.815074 0.930537 16.43280 -0.126448 0.808132 \n",
878
+ "29974 4.601752 0.886162 0.556092 0.994464 5.55010 -2.074547 1.410440 \n",
879
+ "\n",
880
+ " q1 e2 px2 py2 pz2 pt2 eta2 \\\n",
881
+ "0 -1 0.726854 0.715108 -0.374241 -1.351969 -0.654767 0.674536 \n",
882
+ "1 1 36.603944 0.698290 0.936158 -0.772930 0.828964 0.647576 \n",
883
+ "2 1 -25.401295 0.615361 0.550970 0.397435 -0.525276 0.723704 \n",
884
+ "3 -1 -42.101333 0.741327 0.236630 -0.151015 0.472083 0.974790 \n",
885
+ "4 -1 -1.048602 0.958089 0.268982 -0.264064 -0.969478 0.554407 \n",
886
+ "... .. ... ... ... ... ... ... \n",
887
+ "29970 -1 -4.251468 0.813939 73.667501 1.158823 -0.148324 0.954442 \n",
888
+ "29971 1 -1.616737 0.550067 1.162549 -0.311051 -0.897992 0.772986 \n",
889
+ "29972 1 0.628021 0.542946 0.143970 -0.731390 -0.307851 0.734795 \n",
890
+ "29973 -1 -1.111305 0.932663 -0.063397 -0.153664 -0.999193 0.545050 \n",
891
+ "29974 1 -3.578146 0.628637 -9.284831 0.425168 0.147698 0.555115 \n",
892
+ "\n",
893
+ " phi2 q2 \n",
894
+ "0 -0.987639 -1 \n",
895
+ "1 -0.532239 -1 \n",
896
+ "2 0.481615 1 \n",
897
+ "3 0.356768 -1 \n",
898
+ "4 -0.101867 -1 \n",
899
+ "... ... .. \n",
900
+ "29970 0.388798 1 \n",
901
+ "29971 -0.527932 1 \n",
902
+ "29972 -0.766423 -1 \n",
903
+ "29973 0.121883 1 \n",
904
+ "29974 0.964454 -1 \n",
905
+ "\n",
906
+ "[29975 rows x 16 columns]"
907
+ ]
908
+ },
909
+ "execution_count": 16,
910
+ "metadata": {},
911
+ "output_type": "execute_result"
912
+ }
913
+ ],
914
+ "source": [
915
+ "test=pd.read_csv(\"../input/nsutai/test.csv\")\n",
916
+ "\n",
917
+ "X_test=test.drop('id',axis=1)\n",
918
+ "X_test = X_test.drop('event',axis=1)\n",
919
+ "# X_test = X_test.drop('id',axis=1)\n",
920
+ "\n",
921
+ "X_test"
922
+ ]
923
+ },
924
+ {
925
+ "cell_type": "code",
926
+ "execution_count": 17,
927
+ "id": "46a97231",
928
+ "metadata": {
929
+ "execution": {
930
+ "iopub.execute_input": "2022-09-05T14:45:06.371027Z",
931
+ "iopub.status.busy": "2022-09-05T14:45:06.369496Z",
932
+ "iopub.status.idle": "2022-09-05T14:45:10.555987Z",
933
+ "shell.execute_reply": "2022-09-05T14:45:10.554972Z"
934
+ },
935
+ "papermill": {
936
+ "duration": 4.195275,
937
+ "end_time": "2022-09-05T14:45:10.558397",
938
+ "exception": false,
939
+ "start_time": "2022-09-05T14:45:06.363122",
940
+ "status": "completed"
941
+ },
942
+ "tags": []
943
+ },
944
+ "outputs": [],
945
+ "source": [
946
+ "X_test['px12'] = X_test.px1 * X_test.px2\n",
947
+ "X_test['Q12'] = X_test.q1 * X_test.q2\n",
948
+ "X_test['phi12'] = X_test.phi1 * X_test.phi2\n",
949
+ "X_test['eta12'] = X_test.eta1 * X_test.eta2\n",
950
+ "X_test['pt12'] = X_test.pt1 * X_test.pt2\n",
951
+ "X_test['E12'] = X_test.e1 * X_test.e2\n",
952
+ "X_test['pz_diff'] = X_test.pz1 - X_test.pz2\n",
953
+ "X_test['eta_diff'] = X_test.eta1 - X_test.eta2\n",
954
+ "y_pred=catbr.predict(X_test)\n",
955
+ "sample_submission=pd.read_csv(\"../input/nsutai/sample_submission.csv\")\n",
956
+ "sample_submission['mass']=y_pred\n",
957
+ "sample_submission.to_csv('submission.csv',index=False)"
958
+ ]
959
+ },
960
+ {
961
+ "cell_type": "code",
962
+ "execution_count": null,
963
+ "id": "b5dfc088",
964
+ "metadata": {
965
+ "papermill": {
966
+ "duration": 0.005294,
967
+ "end_time": "2022-09-05T14:45:10.569668",
968
+ "exception": false,
969
+ "start_time": "2022-09-05T14:45:10.564374",
970
+ "status": "completed"
971
+ },
972
+ "tags": []
973
+ },
974
+ "outputs": [],
975
+ "source": []
976
+ }
977
+ ],
978
+ "metadata": {
979
+ "kernelspec": {
980
+ "display_name": "Python 3",
981
+ "language": "python",
982
+ "name": "python3"
983
+ },
984
+ "language_info": {
985
+ "codemirror_mode": {
986
+ "name": "ipython",
987
+ "version": 3
988
+ },
989
+ "file_extension": ".py",
990
+ "mimetype": "text/x-python",
991
+ "name": "python",
992
+ "nbconvert_exporter": "python",
993
+ "pygments_lexer": "ipython3",
994
+ "version": "3.7.12"
995
+ },
996
+ "papermill": {
997
+ "default_parameters": {},
998
+ "duration": 6189.341373,
999
+ "end_time": "2022-09-05T14:45:11.404111",
1000
+ "environment_variables": {},
1001
+ "exception": null,
1002
+ "input_path": "__notebook__.ipynb",
1003
+ "output_path": "__notebook__.ipynb",
1004
+ "parameters": {},
1005
+ "start_time": "2022-09-05T13:02:02.062738",
1006
+ "version": "2.3.4"
1007
+ }
1008
+ },
1009
+ "nbformat": 4,
1010
+ "nbformat_minor": 5
1011
+ }