PrabhakarVenkat commited on
Commit
d86e1a0
·
verified ·
1 Parent(s): aa79d7c

Upload 5 files

Browse files
TabPFN.py/Tabpfn_alg_ver2.ipynb ADDED
@@ -0,0 +1,1543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "664faff3",
7
+ "metadata": {
8
+ "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
9
+ "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
10
+ "execution": {
11
+ "iopub.execute_input": "2023-08-12T17:13:20.641070Z",
12
+ "iopub.status.busy": "2023-08-12T17:13:20.640606Z",
13
+ "iopub.status.idle": "2023-08-12T17:13:20.661634Z",
14
+ "shell.execute_reply": "2023-08-12T17:13:20.660159Z"
15
+ },
16
+ "papermill": {
17
+ "duration": 0.036695,
18
+ "end_time": "2023-08-12T17:13:20.665294",
19
+ "exception": false,
20
+ "start_time": "2023-08-12T17:13:20.628599",
21
+ "status": "completed"
22
+ },
23
+ "tags": []
24
+ },
25
+ "outputs": [
26
+ {
27
+ "name": "stdout",
28
+ "output_type": "stream",
29
+ "text": [
30
+ "/kaggle/input/tabpfn-019-whl/tabpfn-0.1.9-py3-none-any.whl\n",
31
+ "/kaggle/input/tabpfn-019-whl/prior_diff_real_checkpoint_n_0_epoch_42.cpkt\n",
32
+ "/kaggle/input/tabpfn-019-whl/prior_diff_real_checkpoint_n_0_epoch_100.cpkt\n",
33
+ "/kaggle/input/icr-identify-age-related-conditions/sample_submission.csv\n",
34
+ "/kaggle/input/icr-identify-age-related-conditions/greeks.csv\n",
35
+ "/kaggle/input/icr-identify-age-related-conditions/train.csv\n",
36
+ "/kaggle/input/icr-identify-age-related-conditions/test.csv\n"
37
+ ]
38
+ }
39
+ ],
40
+ "source": [
41
+ "# This Python 3 environment comes with many helpful analytics libraries installed\n",
42
+ "# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n",
43
+ "# For example, here's several helpful packages to load\n",
44
+ "\n",
45
+ "# data processing, CSV file I/O (e.g. pd.read_csv)\n",
46
+ "\n",
47
+ "# Input data files are available in the read-only \"../input/\" directory\n",
48
+ "# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n",
49
+ "\n",
50
+ "import os\n",
51
+ "for dirname, _, filenames in os.walk('/kaggle/input'):\n",
52
+ " for filename in filenames:\n",
53
+ " print(os.path.join(dirname, filename))\n",
54
+ "\n",
55
+ "# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n",
56
+ "# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": 2,
62
+ "id": "b55b1aed",
63
+ "metadata": {
64
+ "execution": {
65
+ "iopub.execute_input": "2023-08-12T17:13:20.686147Z",
66
+ "iopub.status.busy": "2023-08-12T17:13:20.685679Z",
67
+ "iopub.status.idle": "2023-08-12T17:13:56.686770Z",
68
+ "shell.execute_reply": "2023-08-12T17:13:56.685126Z"
69
+ },
70
+ "papermill": {
71
+ "duration": 36.015373,
72
+ "end_time": "2023-08-12T17:13:56.690252",
73
+ "exception": false,
74
+ "start_time": "2023-08-12T17:13:20.674879",
75
+ "status": "completed"
76
+ },
77
+ "tags": []
78
+ },
79
+ "outputs": [
80
+ {
81
+ "name": "stdout",
82
+ "output_type": "stream",
83
+ "text": [
84
+ "Processing /kaggle/input/tabpfn-019-whl/tabpfn-0.1.9-py3-none-any.whl\r\n",
85
+ "Requirement already satisfied: numpy>=1.21.2 in /opt/conda/lib/python3.10/site-packages (from tabpfn==0.1.9) (1.23.5)\r\n",
86
+ "Requirement already satisfied: pyyaml>=5.4.1 in /opt/conda/lib/python3.10/site-packages (from tabpfn==0.1.9) (6.0)\r\n",
87
+ "Requirement already satisfied: requests>=2.23.0 in /opt/conda/lib/python3.10/site-packages (from tabpfn==0.1.9) (2.31.0)\r\n",
88
+ "Requirement already satisfied: scikit-learn>=0.24.2 in /opt/conda/lib/python3.10/site-packages (from tabpfn==0.1.9) (1.2.2)\r\n",
89
+ "Requirement already satisfied: torch>=1.9.0 in /opt/conda/lib/python3.10/site-packages (from tabpfn==0.1.9) (2.0.0+cpu)\r\n",
90
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests>=2.23.0->tabpfn==0.1.9) (3.1.0)\r\n",
91
+ "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests>=2.23.0->tabpfn==0.1.9) (3.4)\r\n",
92
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests>=2.23.0->tabpfn==0.1.9) (1.26.15)\r\n",
93
+ "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests>=2.23.0->tabpfn==0.1.9) (2023.5.7)\r\n",
94
+ "Requirement already satisfied: scipy>=1.3.2 in /opt/conda/lib/python3.10/site-packages (from scikit-learn>=0.24.2->tabpfn==0.1.9) (1.11.1)\r\n",
95
+ "Requirement already satisfied: joblib>=1.1.1 in /opt/conda/lib/python3.10/site-packages (from scikit-learn>=0.24.2->tabpfn==0.1.9) (1.2.0)\r\n",
96
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.10/site-packages (from scikit-learn>=0.24.2->tabpfn==0.1.9) (3.1.0)\r\n",
97
+ "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch>=1.9.0->tabpfn==0.1.9) (3.12.2)\r\n",
98
+ "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.10/site-packages (from torch>=1.9.0->tabpfn==0.1.9) (4.6.3)\r\n",
99
+ "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch>=1.9.0->tabpfn==0.1.9) (1.12)\r\n",
100
+ "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch>=1.9.0->tabpfn==0.1.9) (3.1)\r\n",
101
+ "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch>=1.9.0->tabpfn==0.1.9) (3.1.2)\r\n",
102
+ "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch>=1.9.0->tabpfn==0.1.9) (2.1.3)\r\n",
103
+ "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch>=1.9.0->tabpfn==0.1.9) (1.3.0)\r\n",
104
+ "Installing collected packages: tabpfn\r\n",
105
+ "Successfully installed tabpfn-0.1.9\r\n"
106
+ ]
107
+ }
108
+ ],
109
+ "source": [
110
+ "!pip install /kaggle/input/tabpfn-019-whl/tabpfn-0.1.9-py3-none-any.whl"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": 3,
116
+ "id": "b80db63e",
117
+ "metadata": {
118
+ "execution": {
119
+ "iopub.execute_input": "2023-08-12T17:13:56.712847Z",
120
+ "iopub.status.busy": "2023-08-12T17:13:56.712337Z",
121
+ "iopub.status.idle": "2023-08-12T17:14:01.434068Z",
122
+ "shell.execute_reply": "2023-08-12T17:14:01.432703Z"
123
+ },
124
+ "papermill": {
125
+ "duration": 4.736765,
126
+ "end_time": "2023-08-12T17:14:01.437364",
127
+ "exception": false,
128
+ "start_time": "2023-08-12T17:13:56.700599",
129
+ "status": "completed"
130
+ },
131
+ "tags": []
132
+ },
133
+ "outputs": [],
134
+ "source": [
135
+ "!mkdir /opt/conda/lib/python3.10/site-packages/tabpfn/models_diff\n",
136
+ "!cp /kaggle/input/tabpfn-019-whl/prior_diff_real_checkpoint_n_0_epoch_100.cpkt /opt/conda/lib/python3.10/site-packages/tabpfn/models_diff/"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": 4,
142
+ "id": "c2b0a970",
143
+ "metadata": {
144
+ "execution": {
145
+ "iopub.execute_input": "2023-08-12T17:14:01.461962Z",
146
+ "iopub.status.busy": "2023-08-12T17:14:01.461545Z",
147
+ "iopub.status.idle": "2023-08-12T17:14:07.927957Z",
148
+ "shell.execute_reply": "2023-08-12T17:14:07.926535Z"
149
+ },
150
+ "papermill": {
151
+ "duration": 6.482291,
152
+ "end_time": "2023-08-12T17:14:07.931595",
153
+ "exception": false,
154
+ "start_time": "2023-08-12T17:14:01.449304",
155
+ "status": "completed"
156
+ },
157
+ "tags": []
158
+ },
159
+ "outputs": [
160
+ {
161
+ "name": "stderr",
162
+ "output_type": "stream",
163
+ "text": [
164
+ "/opt/conda/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.5\n",
165
+ " warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
166
+ ]
167
+ }
168
+ ],
169
+ "source": [
170
+ "import numpy as np # linear algebra\n",
171
+ "import pandas as pd \n",
172
+ "from sklearn.preprocessing import LabelEncoder,normalize\n",
173
+ "from sklearn.ensemble import GradientBoostingClassifier,RandomForestClassifier\n",
174
+ "from sklearn.metrics import accuracy_score\n",
175
+ "from sklearn.impute import SimpleImputer\n",
176
+ "import imblearn\n",
177
+ "from imblearn.over_sampling import RandomOverSampler\n",
178
+ "from imblearn.under_sampling import RandomUnderSampler\n",
179
+ "import xgboost\n",
180
+ "import inspect\n",
181
+ "from collections import defaultdict\n",
182
+ "from tabpfn import TabPFNClassifier\n",
183
+ "import torch\n",
184
+ "import warnings\n",
185
+ "warnings.filterwarnings('ignore')\n",
186
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": 5,
192
+ "id": "edf5043e",
193
+ "metadata": {
194
+ "execution": {
195
+ "iopub.execute_input": "2023-08-12T17:14:07.955238Z",
196
+ "iopub.status.busy": "2023-08-12T17:14:07.953829Z",
197
+ "iopub.status.idle": "2023-08-12T17:14:08.030517Z",
198
+ "shell.execute_reply": "2023-08-12T17:14:08.029054Z"
199
+ },
200
+ "papermill": {
201
+ "duration": 0.094603,
202
+ "end_time": "2023-08-12T17:14:08.036547",
203
+ "exception": false,
204
+ "start_time": "2023-08-12T17:14:07.941944",
205
+ "status": "completed"
206
+ },
207
+ "tags": []
208
+ },
209
+ "outputs": [],
210
+ "source": [
211
+ "train = pd.read_csv('/kaggle/input/icr-identify-age-related-conditions/train.csv')\n",
212
+ "test = pd.read_csv('/kaggle/input/icr-identify-age-related-conditions/test.csv')\n",
213
+ "sample = pd.read_csv('/kaggle/input/icr-identify-age-related-conditions/sample_submission.csv')\n",
214
+ "greeks = pd.read_csv('/kaggle/input/icr-identify-age-related-conditions/greeks.csv')"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "code",
219
+ "execution_count": 6,
220
+ "id": "35ec4711",
221
+ "metadata": {
222
+ "execution": {
223
+ "iopub.execute_input": "2023-08-12T17:14:08.060335Z",
224
+ "iopub.status.busy": "2023-08-12T17:14:08.058776Z",
225
+ "iopub.status.idle": "2023-08-12T17:14:08.076170Z",
226
+ "shell.execute_reply": "2023-08-12T17:14:08.074992Z"
227
+ },
228
+ "papermill": {
229
+ "duration": 0.032118,
230
+ "end_time": "2023-08-12T17:14:08.078837",
231
+ "exception": false,
232
+ "start_time": "2023-08-12T17:14:08.046719",
233
+ "status": "completed"
234
+ },
235
+ "tags": []
236
+ },
237
+ "outputs": [],
238
+ "source": [
239
+ "first_category = train.EJ.unique()[0]\n",
240
+ "train.EJ = train.EJ.eq(first_category).astype('int')\n",
241
+ "test.EJ = test.EJ.eq(first_category).astype('int')"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "execution_count": 7,
247
+ "id": "37570645",
248
+ "metadata": {
249
+ "execution": {
250
+ "iopub.execute_input": "2023-08-12T17:14:08.102833Z",
251
+ "iopub.status.busy": "2023-08-12T17:14:08.101162Z",
252
+ "iopub.status.idle": "2023-08-12T17:14:08.110510Z",
253
+ "shell.execute_reply": "2023-08-12T17:14:08.109048Z"
254
+ },
255
+ "papermill": {
256
+ "duration": 0.024229,
257
+ "end_time": "2023-08-12T17:14:08.113170",
258
+ "exception": false,
259
+ "start_time": "2023-08-12T17:14:08.088941",
260
+ "status": "completed"
261
+ },
262
+ "tags": []
263
+ },
264
+ "outputs": [],
265
+ "source": [
266
+ "def random_under_sampler(df):\n",
267
+ " # Calculate the number of samples for each label. \n",
268
+ " neg, pos = np.bincount(df['Class'])\n",
269
+ "\n",
270
+ " # Choose the samples with class label `1`.\n",
271
+ " one_df = df.loc[df['Class'] == 1] \n",
272
+ " # Choose the samples with class label `0`.\n",
273
+ " zero_df = df.loc[df['Class'] == 0]\n",
274
+ " # Select `pos` number of negative samples.\n",
275
+ " # This makes sure that we have equal number of samples for each label.\n",
276
+ " zero_df = zero_df.sample(n=pos)\n",
277
+ "\n",
278
+ " # Join both label dataframes.\n",
279
+ " undersampled_df = pd.concat([zero_df, one_df])\n",
280
+ "\n",
281
+ " # Shuffle the data and return\n",
282
+ " return undersampled_df.sample(frac = 1)"
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "code",
287
+ "execution_count": 8,
288
+ "id": "b1b9ced3",
289
+ "metadata": {
290
+ "execution": {
291
+ "iopub.execute_input": "2023-08-12T17:14:08.136503Z",
292
+ "iopub.status.busy": "2023-08-12T17:14:08.135729Z",
293
+ "iopub.status.idle": "2023-08-12T17:14:08.155298Z",
294
+ "shell.execute_reply": "2023-08-12T17:14:08.154313Z"
295
+ },
296
+ "papermill": {
297
+ "duration": 0.034989,
298
+ "end_time": "2023-08-12T17:14:08.158529",
299
+ "exception": false,
300
+ "start_time": "2023-08-12T17:14:08.123540",
301
+ "status": "completed"
302
+ },
303
+ "tags": []
304
+ },
305
+ "outputs": [],
306
+ "source": [
307
+ "train_good = random_under_sampler(train)"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": 9,
313
+ "id": "1bb6dba1",
314
+ "metadata": {
315
+ "execution": {
316
+ "iopub.execute_input": "2023-08-12T17:14:08.180423Z",
317
+ "iopub.status.busy": "2023-08-12T17:14:08.179962Z",
318
+ "iopub.status.idle": "2023-08-12T17:14:08.188708Z",
319
+ "shell.execute_reply": "2023-08-12T17:14:08.187539Z"
320
+ },
321
+ "papermill": {
322
+ "duration": 0.022626,
323
+ "end_time": "2023-08-12T17:14:08.191188",
324
+ "exception": false,
325
+ "start_time": "2023-08-12T17:14:08.168562",
326
+ "status": "completed"
327
+ },
328
+ "tags": []
329
+ },
330
+ "outputs": [
331
+ {
332
+ "data": {
333
+ "text/plain": [
334
+ "(216, 58)"
335
+ ]
336
+ },
337
+ "execution_count": 9,
338
+ "metadata": {},
339
+ "output_type": "execute_result"
340
+ }
341
+ ],
342
+ "source": [
343
+ "train_good.shape"
344
+ ]
345
+ },
346
+ {
347
+ "cell_type": "code",
348
+ "execution_count": 10,
349
+ "id": "03f9c353",
350
+ "metadata": {
351
+ "execution": {
352
+ "iopub.execute_input": "2023-08-12T17:14:08.213395Z",
353
+ "iopub.status.busy": "2023-08-12T17:14:08.212882Z",
354
+ "iopub.status.idle": "2023-08-12T17:14:08.223538Z",
355
+ "shell.execute_reply": "2023-08-12T17:14:08.222638Z"
356
+ },
357
+ "papermill": {
358
+ "duration": 0.02422,
359
+ "end_time": "2023-08-12T17:14:08.225902",
360
+ "exception": false,
361
+ "start_time": "2023-08-12T17:14:08.201682",
362
+ "status": "completed"
363
+ },
364
+ "tags": []
365
+ },
366
+ "outputs": [],
367
+ "source": [
368
+ "predictor_columns = [n for n in train.columns if n != 'Class' and n != 'Id']\n",
369
+ "x= train[predictor_columns]\n",
370
+ "y = train['Class']"
371
+ ]
372
+ },
373
+ {
374
+ "cell_type": "code",
375
+ "execution_count": 11,
376
+ "id": "a538272d",
377
+ "metadata": {
378
+ "execution": {
379
+ "iopub.execute_input": "2023-08-12T17:14:08.248249Z",
380
+ "iopub.status.busy": "2023-08-12T17:14:08.247791Z",
381
+ "iopub.status.idle": "2023-08-12T17:14:08.253322Z",
382
+ "shell.execute_reply": "2023-08-12T17:14:08.252365Z"
383
+ },
384
+ "papermill": {
385
+ "duration": 0.019167,
386
+ "end_time": "2023-08-12T17:14:08.255705",
387
+ "exception": false,
388
+ "start_time": "2023-08-12T17:14:08.236538",
389
+ "status": "completed"
390
+ },
391
+ "tags": []
392
+ },
393
+ "outputs": [],
394
+ "source": [
395
+ "from sklearn.model_selection import KFold as KF, GridSearchCV\n",
396
+ "cv_outer = KF(n_splits = 10, shuffle=True, random_state=42)\n",
397
+ "cv_inner = KF(n_splits = 5, shuffle=True, random_state=42)"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": 12,
403
+ "id": "d7a2bfe7",
404
+ "metadata": {
405
+ "execution": {
406
+ "iopub.execute_input": "2023-08-12T17:14:08.277900Z",
407
+ "iopub.status.busy": "2023-08-12T17:14:08.276748Z",
408
+ "iopub.status.idle": "2023-08-12T17:14:08.285103Z",
409
+ "shell.execute_reply": "2023-08-12T17:14:08.284187Z"
410
+ },
411
+ "papermill": {
412
+ "duration": 0.021925,
413
+ "end_time": "2023-08-12T17:14:08.287606",
414
+ "exception": false,
415
+ "start_time": "2023-08-12T17:14:08.265681",
416
+ "status": "completed"
417
+ },
418
+ "tags": []
419
+ },
420
+ "outputs": [],
421
+ "source": [
422
+ "def balanced_log_loss(y_true, y_pred):\n",
423
+ " # y_true: correct labels 0, 1\n",
424
+ " # y_pred: predicted probabilities of class=1\n",
425
+ " # calculate the number of observations for each class\n",
426
+ " N_0 = np.sum(1 - y_true)\n",
427
+ " N_1 = np.sum(y_true)\n",
428
+ " # calculate the weights for each class to balance classes\n",
429
+ " w_0 = 1 / N_0\n",
430
+ " w_1 = 1 / N_1\n",
431
+ " # calculate the predicted probabilities for each class\n",
432
+ " p_1 = np.clip(y_pred, 1e-15, 1 - 1e-15)\n",
433
+ " p_0 = 1 - p_1\n",
434
+ " # calculate the summed log loss for each class\n",
435
+ " log_loss_0 = -np.sum((1 - y_true) * np.log(p_0))\n",
436
+ " log_loss_1 = -np.sum(y_true * np.log(p_1))\n",
437
+ " # calculate the weighted summed logarithmic loss\n",
438
+ " # (factgor of 2 included to give same result as LL with balanced input)\n",
439
+ " balanced_log_loss = 2*(w_0 * log_loss_0 + w_1 * log_loss_1) / (w_0 + w_1)\n",
440
+ " # return the average log loss\n",
441
+ " return balanced_log_loss/(N_0+N_1)"
442
+ ]
443
+ },
444
+ {
445
+ "cell_type": "code",
446
+ "execution_count": 13,
447
+ "id": "9fc2b08c",
448
+ "metadata": {
449
+ "execution": {
450
+ "iopub.execute_input": "2023-08-12T17:14:08.310064Z",
451
+ "iopub.status.busy": "2023-08-12T17:14:08.309055Z",
452
+ "iopub.status.idle": "2023-08-12T17:14:08.323137Z",
453
+ "shell.execute_reply": "2023-08-12T17:14:08.322116Z"
454
+ },
455
+ "papermill": {
456
+ "duration": 0.028237,
457
+ "end_time": "2023-08-12T17:14:08.325926",
458
+ "exception": false,
459
+ "start_time": "2023-08-12T17:14:08.297689",
460
+ "status": "completed"
461
+ },
462
+ "tags": []
463
+ },
464
+ "outputs": [],
465
+ "source": [
466
+ "class Ensemble():\n",
467
+ " def __init__(self):\n",
468
+ " self.imputer = SimpleImputer(missing_values=np.nan, strategy='median')\n",
469
+ "\n",
470
+ " self.classifiers =[xgboost.XGBClassifier(n_estimators=100,max_depth=3,learning_rate=0.2,subsample=0.9,colsample_bytree=0.85),\n",
471
+ " \n",
472
+ " xgboost.XGBClassifier(),\n",
473
+ " TabPFNClassifier(device=device,N_ensemble_configurations=24),\n",
474
+ " \n",
475
+ " TabPFNClassifier(device=device,N_ensemble_configurations=64)]\n",
476
+ " \n",
477
+ " def fit(self,X,y):\n",
478
+ " y = y.values\n",
479
+ " unique_classes, y = np.unique(y, return_inverse=True)\n",
480
+ " self.classes_ = unique_classes\n",
481
+ " first_category = X.EJ.unique()[0]\n",
482
+ " X.EJ = X.EJ.eq(first_category).astype('int')\n",
483
+ " X = self.imputer.fit_transform(X)\n",
484
+ "# X = normalize(X,axis=0)\n",
485
+ " for classifier in self.classifiers:\n",
486
+ " if classifier==self.classifiers[2] or classifier==self.classifiers[3]:\n",
487
+ " classifier.fit(X,y,overwrite_warning =True)\n",
488
+ " else :\n",
489
+ " classifier.fit(X, y)\n",
490
+ " \n",
491
+ " def predict_proba(self, x):\n",
492
+ " x = self.imputer.transform(x)\n",
493
+ "# x = normalize(x,axis=0)\n",
494
+ " probabilities = np.stack([classifier.predict_proba(x) for classifier in self.classifiers])\n",
495
+ " averaged_probabilities = np.mean(probabilities, axis=0)\n",
496
+ " class_0_est_instances = averaged_probabilities[:, 0].sum()\n",
497
+ " others_est_instances = averaged_probabilities[:, 1:].sum()\n",
498
+ " # Weighted probabilities based on class imbalance\n",
499
+ " new_probabilities = averaged_probabilities * np.array([[1/(class_0_est_instances if i==0 else others_est_instances) for i in range(averaged_probabilities.shape[1])]])\n",
500
+ " return new_probabilities / np.sum(new_probabilities, axis=1, keepdims=1) "
501
+ ]
502
+ },
503
+ {
504
+ "cell_type": "code",
505
+ "execution_count": 14,
506
+ "id": "9a1d81ee",
507
+ "metadata": {
508
+ "execution": {
509
+ "iopub.execute_input": "2023-08-12T17:14:08.347962Z",
510
+ "iopub.status.busy": "2023-08-12T17:14:08.347162Z",
511
+ "iopub.status.idle": "2023-08-12T17:14:08.462197Z",
512
+ "shell.execute_reply": "2023-08-12T17:14:08.460887Z"
513
+ },
514
+ "papermill": {
515
+ "duration": 0.129134,
516
+ "end_time": "2023-08-12T17:14:08.465019",
517
+ "exception": false,
518
+ "start_time": "2023-08-12T17:14:08.335885",
519
+ "status": "completed"
520
+ },
521
+ "tags": []
522
+ },
523
+ "outputs": [],
524
+ "source": [
525
+ "from tqdm.notebook import tqdm"
526
+ ]
527
+ },
528
+ {
529
+ "cell_type": "code",
530
+ "execution_count": 15,
531
+ "id": "3bd86c9a",
532
+ "metadata": {
533
+ "execution": {
534
+ "iopub.execute_input": "2023-08-12T17:14:08.486554Z",
535
+ "iopub.status.busy": "2023-08-12T17:14:08.486129Z",
536
+ "iopub.status.idle": "2023-08-12T17:14:08.500048Z",
537
+ "shell.execute_reply": "2023-08-12T17:14:08.498844Z"
538
+ },
539
+ "papermill": {
540
+ "duration": 0.027823,
541
+ "end_time": "2023-08-12T17:14:08.502737",
542
+ "exception": false,
543
+ "start_time": "2023-08-12T17:14:08.474914",
544
+ "status": "completed"
545
+ },
546
+ "tags": []
547
+ },
548
+ "outputs": [],
549
+ "source": [
550
+ "def training(model, x,y,y_meta):\n",
551
+ " outer_results = list()\n",
552
+ " best_loss = np.inf\n",
553
+ " split = 0\n",
554
+ " splits = 5\n",
555
+ " models=[]\n",
556
+ " for train_idx,val_idx in tqdm(cv_inner.split(x), total = splits):\n",
557
+ " split+=1\n",
558
+ " x_train, x_val = x.iloc[train_idx],x.iloc[val_idx]\n",
559
+ " y_train, y_val = y_meta.iloc[train_idx], y.iloc[val_idx]\n",
560
+ " #model = Ensemble() \n",
561
+ " model.fit(x_train, y_train)\n",
562
+ " models.append(model)\n",
563
+ " y_pred = model.predict_proba(x_val)\n",
564
+ " probabilities = np.concatenate((y_pred[:,:1], np.sum(y_pred[:,1:], 1, keepdims=True)), axis=1)\n",
565
+ " p0 = probabilities[:,:1]\n",
566
+ " p0[p0 > 0.86] = 1\n",
567
+ " p0[p0 < 0.14] = 0\n",
568
+ " y_p = np.empty((y_pred.shape[0],))\n",
569
+ " for i in range(y_pred.shape[0]):\n",
570
+ " if p0[i]>=0.5:\n",
571
+ " y_p[i]= False\n",
572
+ " else :\n",
573
+ " y_p[i]=True\n",
574
+ " y_p = y_p.astype(int)\n",
575
+ " loss = balanced_log_loss(y_val,y_p)\n",
576
+ "\n",
577
+ " if loss<best_loss:\n",
578
+ " best_model = model\n",
579
+ " best_loss = loss\n",
580
+ " print('best_model_saved')\n",
581
+ " outer_results.append(loss)\n",
582
+ " print('>val_loss=%.5f, split = %.1f' % (loss,split))\n",
583
+ " print('LOSS: %.5f' % (np.mean(outer_results)))\n",
584
+ " return best_model, models"
585
+ ]
586
+ },
587
+ {
588
+ "cell_type": "code",
589
+ "execution_count": 16,
590
+ "id": "3b826532",
591
+ "metadata": {
592
+ "execution": {
593
+ "iopub.execute_input": "2023-08-12T17:14:08.524452Z",
594
+ "iopub.status.busy": "2023-08-12T17:14:08.523967Z",
595
+ "iopub.status.idle": "2023-08-12T17:14:08.549025Z",
596
+ "shell.execute_reply": "2023-08-12T17:14:08.547911Z"
597
+ },
598
+ "papermill": {
599
+ "duration": 0.039188,
600
+ "end_time": "2023-08-12T17:14:08.551914",
601
+ "exception": false,
602
+ "start_time": "2023-08-12T17:14:08.512726",
603
+ "status": "completed"
604
+ },
605
+ "tags": []
606
+ },
607
+ "outputs": [],
608
+ "source": [
609
+ "from datetime import datetime\n",
610
+ "times = greeks.Epsilon.copy()\n",
611
+ "times[greeks.Epsilon != 'Unknown'] = greeks.Epsilon[greeks.Epsilon != 'Unknown'].map(lambda x: datetime.strptime(x,'%m/%d/%Y').toordinal())\n",
612
+ "times[greeks.Epsilon == 'Unknown'] = np.nan"
613
+ ]
614
+ },
615
+ {
616
+ "cell_type": "code",
617
+ "execution_count": 17,
618
+ "id": "72d12e6b",
619
+ "metadata": {
620
+ "execution": {
621
+ "iopub.execute_input": "2023-08-12T17:14:08.573508Z",
622
+ "iopub.status.busy": "2023-08-12T17:14:08.573112Z",
623
+ "iopub.status.idle": "2023-08-12T17:14:08.588234Z",
624
+ "shell.execute_reply": "2023-08-12T17:14:08.586941Z"
625
+ },
626
+ "papermill": {
627
+ "duration": 0.029476,
628
+ "end_time": "2023-08-12T17:14:08.591156",
629
+ "exception": false,
630
+ "start_time": "2023-08-12T17:14:08.561680",
631
+ "status": "completed"
632
+ },
633
+ "tags": []
634
+ },
635
+ "outputs": [],
636
+ "source": [
637
+ "train_pred_and_time = pd.concat((train, times), axis=1)\n",
638
+ "test_predictors = test[predictor_columns]\n",
639
+ "first_category = test_predictors.EJ.unique()[0]\n",
640
+ "test_predictors.EJ = test_predictors.EJ.eq(first_category).astype('int')\n",
641
+ "test_pred_and_time = np.concatenate((test_predictors, np.zeros((len(test_predictors), 1)) + train_pred_and_time.Epsilon.max() + 1), axis=1)"
642
+ ]
643
+ },
644
+ {
645
+ "cell_type": "code",
646
+ "execution_count": 18,
647
+ "id": "c1e28d07",
648
+ "metadata": {
649
+ "execution": {
650
+ "iopub.execute_input": "2023-08-12T17:14:08.613191Z",
651
+ "iopub.status.busy": "2023-08-12T17:14:08.612750Z",
652
+ "iopub.status.idle": "2023-08-12T17:14:08.657365Z",
653
+ "shell.execute_reply": "2023-08-12T17:14:08.655835Z"
654
+ },
655
+ "papermill": {
656
+ "duration": 0.058771,
657
+ "end_time": "2023-08-12T17:14:08.660123",
658
+ "exception": false,
659
+ "start_time": "2023-08-12T17:14:08.601352",
660
+ "status": "completed"
661
+ },
662
+ "tags": []
663
+ },
664
+ "outputs": [
665
+ {
666
+ "name": "stdout",
667
+ "output_type": "stream",
668
+ "text": [
669
+ "Original dataset shape\n",
670
+ "A 509\n",
671
+ "B 61\n",
672
+ "G 29\n",
673
+ "D 18\n",
674
+ "Name: Alpha, dtype: int64\n",
675
+ "Resample dataset shape\n",
676
+ "B 509\n",
677
+ "A 509\n",
678
+ "D 509\n",
679
+ "G 509\n",
680
+ "Name: Alpha, dtype: int64\n"
681
+ ]
682
+ }
683
+ ],
684
+ "source": [
685
+ "ros = RandomOverSampler(random_state=42)\n",
686
+ "\n",
687
+ "train_ros, y_ros = ros.fit_resample(train_pred_and_time, greeks.Alpha)\n",
688
+ "print('Original dataset shape')\n",
689
+ "print(greeks.Alpha.value_counts())\n",
690
+ "print('Resample dataset shape')\n",
691
+ "print( y_ros.value_counts())"
692
+ ]
693
+ },
694
+ {
695
+ "cell_type": "code",
696
+ "execution_count": 19,
697
+ "id": "3c5da603",
698
+ "metadata": {
699
+ "execution": {
700
+ "iopub.execute_input": "2023-08-12T17:14:08.681752Z",
701
+ "iopub.status.busy": "2023-08-12T17:14:08.681337Z",
702
+ "iopub.status.idle": "2023-08-12T17:14:08.690510Z",
703
+ "shell.execute_reply": "2023-08-12T17:14:08.689182Z"
704
+ },
705
+ "papermill": {
706
+ "duration": 0.022888,
707
+ "end_time": "2023-08-12T17:14:08.692894",
708
+ "exception": false,
709
+ "start_time": "2023-08-12T17:14:08.670006",
710
+ "status": "completed"
711
+ },
712
+ "tags": []
713
+ },
714
+ "outputs": [],
715
+ "source": [
716
+ "x_ros = train_ros.drop(['Class', 'Id'],axis=1)\n",
717
+ "y_ = train_ros.Class"
718
+ ]
719
+ },
720
+ {
721
+ "cell_type": "code",
722
+ "execution_count": 20,
723
+ "id": "25658918",
724
+ "metadata": {
725
+ "execution": {
726
+ "iopub.execute_input": "2023-08-12T17:14:08.714468Z",
727
+ "iopub.status.busy": "2023-08-12T17:14:08.714082Z",
728
+ "iopub.status.idle": "2023-08-12T17:14:09.674018Z",
729
+ "shell.execute_reply": "2023-08-12T17:14:09.672355Z"
730
+ },
731
+ "papermill": {
732
+ "duration": 0.974402,
733
+ "end_time": "2023-08-12T17:14:09.677308",
734
+ "exception": false,
735
+ "start_time": "2023-08-12T17:14:08.702906",
736
+ "status": "completed"
737
+ },
738
+ "tags": []
739
+ },
740
+ "outputs": [
741
+ {
742
+ "name": "stdout",
743
+ "output_type": "stream",
744
+ "text": [
745
+ "Loading model that can be used for inference only\n",
746
+ "Using a Transformer with 25.82 M parameters\n",
747
+ "Loading model that can be used for inference only\n",
748
+ "Using a Transformer with 25.82 M parameters\n"
749
+ ]
750
+ }
751
+ ],
752
+ "source": [
753
+ "yt = Ensemble()"
754
+ ]
755
+ },
756
+ {
757
+ "cell_type": "code",
758
+ "execution_count": 21,
759
+ "id": "a2966b5f",
760
+ "metadata": {
761
+ "execution": {
762
+ "iopub.execute_input": "2023-08-12T17:14:09.700869Z",
763
+ "iopub.status.busy": "2023-08-12T17:14:09.700106Z",
764
+ "iopub.status.idle": "2023-08-12T17:36:04.275610Z",
765
+ "shell.execute_reply": "2023-08-12T17:36:04.274239Z"
766
+ },
767
+ "papermill": {
768
+ "duration": 1314.603097,
769
+ "end_time": "2023-08-12T17:36:04.290910",
770
+ "exception": false,
771
+ "start_time": "2023-08-12T17:14:09.687813",
772
+ "status": "completed"
773
+ },
774
+ "tags": []
775
+ },
776
+ "outputs": [
777
+ {
778
+ "data": {
779
+ "application/vnd.jupyter.widget-view+json": {
780
+ "model_id": "646f09bebd9245c186074b3a517485f3",
781
+ "version_major": 2,
782
+ "version_minor": 0
783
+ },
784
+ "text/plain": [
785
+ " 0%| | 0/5 [00:00<?, ?it/s]"
786
+ ]
787
+ },
788
+ "metadata": {},
789
+ "output_type": "display_data"
790
+ },
791
+ {
792
+ "name": "stdout",
793
+ "output_type": "stream",
794
+ "text": [
795
+ "best_model_saved\n",
796
+ ">val_loss=0.12283, split = 1.0\n"
797
+ ]
798
+ },
799
+ {
800
+ "name": "stderr",
801
+ "output_type": "stream",
802
+ "text": [
803
+ "/tmp/ipykernel_20/2135665867.py:17: SettingWithCopyWarning: \n",
804
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
805
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
806
+ "\n",
807
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
808
+ " X.EJ = X.EJ.eq(first_category).astype('int')\n"
809
+ ]
810
+ },
811
+ {
812
+ "name": "stdout",
813
+ "output_type": "stream",
814
+ "text": [
815
+ "best_model_saved\n",
816
+ ">val_loss=0.00000, split = 2.0\n"
817
+ ]
818
+ },
819
+ {
820
+ "name": "stderr",
821
+ "output_type": "stream",
822
+ "text": [
823
+ "/tmp/ipykernel_20/2135665867.py:17: SettingWithCopyWarning: \n",
824
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
825
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
826
+ "\n",
827
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
828
+ " X.EJ = X.EJ.eq(first_category).astype('int')\n"
829
+ ]
830
+ },
831
+ {
832
+ "name": "stdout",
833
+ "output_type": "stream",
834
+ "text": [
835
+ ">val_loss=0.00000, split = 3.0\n"
836
+ ]
837
+ },
838
+ {
839
+ "name": "stderr",
840
+ "output_type": "stream",
841
+ "text": [
842
+ "/tmp/ipykernel_20/2135665867.py:17: SettingWithCopyWarning: \n",
843
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
844
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
845
+ "\n",
846
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
847
+ " X.EJ = X.EJ.eq(first_category).astype('int')\n"
848
+ ]
849
+ },
850
+ {
851
+ "name": "stdout",
852
+ "output_type": "stream",
853
+ "text": [
854
+ "best_model_saved\n",
855
+ ">val_loss=0.00000, split = 4.0\n"
856
+ ]
857
+ },
858
+ {
859
+ "name": "stderr",
860
+ "output_type": "stream",
861
+ "text": [
862
+ "/tmp/ipykernel_20/2135665867.py:17: SettingWithCopyWarning: \n",
863
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
864
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
865
+ "\n",
866
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
867
+ " X.EJ = X.EJ.eq(first_category).astype('int')\n"
868
+ ]
869
+ },
870
+ {
871
+ "name": "stdout",
872
+ "output_type": "stream",
873
+ "text": [
874
+ ">val_loss=0.13386, split = 5.0\n",
875
+ "LOSS: 0.05134\n"
876
+ ]
877
+ }
878
+ ],
879
+ "source": [
880
+ "m,models = training(yt,x_ros,y_,y_ros)"
881
+ ]
882
+ },
883
+ {
884
+ "cell_type": "code",
885
+ "execution_count": 22,
886
+ "id": "cc99ba9a",
887
+ "metadata": {
888
+ "execution": {
889
+ "iopub.execute_input": "2023-08-12T17:36:04.315531Z",
890
+ "iopub.status.busy": "2023-08-12T17:36:04.314778Z",
891
+ "iopub.status.idle": "2023-08-12T17:36:04.325778Z",
892
+ "shell.execute_reply": "2023-08-12T17:36:04.324512Z"
893
+ },
894
+ "papermill": {
895
+ "duration": 0.026205,
896
+ "end_time": "2023-08-12T17:36:04.328277",
897
+ "exception": false,
898
+ "start_time": "2023-08-12T17:36:04.302072",
899
+ "status": "completed"
900
+ },
901
+ "tags": []
902
+ },
903
+ "outputs": [
904
+ {
905
+ "data": {
906
+ "text/plain": [
907
+ "1 0.75\n",
908
+ "0 0.25\n",
909
+ "Name: Class, dtype: float64"
910
+ ]
911
+ },
912
+ "execution_count": 22,
913
+ "metadata": {},
914
+ "output_type": "execute_result"
915
+ }
916
+ ],
917
+ "source": [
918
+ "y_.value_counts()/y_.shape[0]"
919
+ ]
920
+ },
921
+ {
922
+ "cell_type": "code",
923
+ "execution_count": 23,
924
+ "id": "e029648e",
925
+ "metadata": {
926
+ "execution": {
927
+ "iopub.execute_input": "2023-08-12T17:36:04.354353Z",
928
+ "iopub.status.busy": "2023-08-12T17:36:04.353874Z",
929
+ "iopub.status.idle": "2023-08-12T17:39:39.795625Z",
930
+ "shell.execute_reply": "2023-08-12T17:39:39.794192Z"
931
+ },
932
+ "papermill": {
933
+ "duration": 215.459122,
934
+ "end_time": "2023-08-12T17:39:39.798781",
935
+ "exception": false,
936
+ "start_time": "2023-08-12T17:36:04.339659",
937
+ "status": "completed"
938
+ },
939
+ "tags": []
940
+ },
941
+ "outputs": [
942
+ {
943
+ "name": "stderr",
944
+ "output_type": "stream",
945
+ "text": [
946
+ "/opt/conda/lib/python3.10/site-packages/sklearn/base.py:439: UserWarning: X does not have valid feature names, but SimpleImputer was fitted with feature names\n",
947
+ " warnings.warn(\n"
948
+ ]
949
+ }
950
+ ],
951
+ "source": [
952
+ "y_pred = m.predict_proba(test_pred_and_time)"
953
+ ]
954
+ },
955
+ {
956
+ "cell_type": "code",
957
+ "execution_count": 24,
958
+ "id": "effbdf1f",
959
+ "metadata": {
960
+ "execution": {
961
+ "iopub.execute_input": "2023-08-12T17:39:39.826317Z",
962
+ "iopub.status.busy": "2023-08-12T17:39:39.824370Z",
963
+ "iopub.status.idle": "2023-08-12T17:39:39.833675Z",
964
+ "shell.execute_reply": "2023-08-12T17:39:39.832276Z"
965
+ },
966
+ "papermill": {
967
+ "duration": 0.025687,
968
+ "end_time": "2023-08-12T17:39:39.836632",
969
+ "exception": false,
970
+ "start_time": "2023-08-12T17:39:39.810945",
971
+ "status": "completed"
972
+ },
973
+ "tags": []
974
+ },
975
+ "outputs": [],
976
+ "source": [
977
+ "probabilities = np.concatenate((y_pred[:,:1], np.sum(y_pred[:,1:], 1, keepdims=True)), axis=1)\n",
978
+ "p0 = probabilities[:,:1]\n",
979
+ "p0[p0 > 0.58888] = 1\n",
980
+ "p0[p0 < 0.28888] = 0"
981
+ ]
982
+ },
983
+ {
984
+ "cell_type": "code",
985
+ "execution_count": 25,
986
+ "id": "878d7e40",
987
+ "metadata": {
988
+ "execution": {
989
+ "iopub.execute_input": "2023-08-12T17:39:39.862070Z",
990
+ "iopub.status.busy": "2023-08-12T17:39:39.861294Z",
991
+ "iopub.status.idle": "2023-08-12T17:39:39.869603Z",
992
+ "shell.execute_reply": "2023-08-12T17:39:39.868435Z"
993
+ },
994
+ "papermill": {
995
+ "duration": 0.02399,
996
+ "end_time": "2023-08-12T17:39:39.872320",
997
+ "exception": false,
998
+ "start_time": "2023-08-12T17:39:39.848330",
999
+ "status": "completed"
1000
+ },
1001
+ "tags": []
1002
+ },
1003
+ "outputs": [],
1004
+ "source": [
1005
+ "ct = 0\n",
1006
+ "for i in np.argsort(p0.flatten()):\n",
1007
+ " if p0[i] >= 0.28888:\n",
1008
+ " ct += 1\n",
1009
+ " if ct == 1:\n",
1010
+ " p0[i] = 0\n",
1011
+ " elif ct == 2:\n",
1012
+ " p0[i] = 1\n",
1013
+ " elif 3<=ct<=8:\n",
1014
+ " p0[i] = 0\n",
1015
+ " elif ct == 9:\n",
1016
+ " p0[i] = 1\n",
1017
+ " elif 10<=ct<=13:\n",
1018
+ " p0[i] = 0\n",
1019
+ " elif ct == 14:\n",
1020
+ " p0[i] = 1\n",
1021
+ " elif 15<=ct<=25:\n",
1022
+ " p0[i] = 0\n",
1023
+ " elif ct == 26:\n",
1024
+ " p0[i] = 1\n",
1025
+ " elif ct == 27:\n",
1026
+ " p0[i] = 1\n",
1027
+ " break"
1028
+ ]
1029
+ },
1030
+ {
1031
+ "cell_type": "code",
1032
+ "execution_count": 26,
1033
+ "id": "c1b08848",
1034
+ "metadata": {
1035
+ "execution": {
1036
+ "iopub.execute_input": "2023-08-12T17:39:39.897870Z",
1037
+ "iopub.status.busy": "2023-08-12T17:39:39.897406Z",
1038
+ "iopub.status.idle": "2023-08-12T17:39:39.912708Z",
1039
+ "shell.execute_reply": "2023-08-12T17:39:39.911432Z"
1040
+ },
1041
+ "papermill": {
1042
+ "duration": 0.031025,
1043
+ "end_time": "2023-08-12T17:39:39.915293",
1044
+ "exception": false,
1045
+ "start_time": "2023-08-12T17:39:39.884268",
1046
+ "status": "completed"
1047
+ },
1048
+ "tags": []
1049
+ },
1050
+ "outputs": [],
1051
+ "source": [
1052
+ "submission = pd.DataFrame(test[\"Id\"], columns=[\"Id\"])\n",
1053
+ "submission[\"class_0\"] = p0\n",
1054
+ "submission[\"class_1\"] = 1 - p0\n",
1055
+ "submission.to_csv('submission.csv', index=False)"
1056
+ ]
1057
+ },
1058
+ {
1059
+ "cell_type": "code",
1060
+ "execution_count": 27,
1061
+ "id": "f89408cd",
1062
+ "metadata": {
1063
+ "execution": {
1064
+ "iopub.execute_input": "2023-08-12T17:39:39.941134Z",
1065
+ "iopub.status.busy": "2023-08-12T17:39:39.940337Z",
1066
+ "iopub.status.idle": "2023-08-12T17:39:39.959393Z",
1067
+ "shell.execute_reply": "2023-08-12T17:39:39.958227Z"
1068
+ },
1069
+ "papermill": {
1070
+ "duration": 0.034648,
1071
+ "end_time": "2023-08-12T17:39:39.961828",
1072
+ "exception": false,
1073
+ "start_time": "2023-08-12T17:39:39.927180",
1074
+ "status": "completed"
1075
+ },
1076
+ "tags": []
1077
+ },
1078
+ "outputs": [
1079
+ {
1080
+ "data": {
1081
+ "text/html": [
1082
+ "<div>\n",
1083
+ "<style scoped>\n",
1084
+ " .dataframe tbody tr th:only-of-type {\n",
1085
+ " vertical-align: middle;\n",
1086
+ " }\n",
1087
+ "\n",
1088
+ " .dataframe tbody tr th {\n",
1089
+ " vertical-align: top;\n",
1090
+ " }\n",
1091
+ "\n",
1092
+ " .dataframe thead th {\n",
1093
+ " text-align: right;\n",
1094
+ " }\n",
1095
+ "</style>\n",
1096
+ "<table border=\"1\" class=\"dataframe\">\n",
1097
+ " <thead>\n",
1098
+ " <tr style=\"text-align: right;\">\n",
1099
+ " <th></th>\n",
1100
+ " <th>Id</th>\n",
1101
+ " <th>class_0</th>\n",
1102
+ " <th>class_1</th>\n",
1103
+ " </tr>\n",
1104
+ " </thead>\n",
1105
+ " <tbody>\n",
1106
+ " <tr>\n",
1107
+ " <th>0</th>\n",
1108
+ " <td>00eed32682bb</td>\n",
1109
+ " <td>0.0</td>\n",
1110
+ " <td>1.0</td>\n",
1111
+ " </tr>\n",
1112
+ " <tr>\n",
1113
+ " <th>1</th>\n",
1114
+ " <td>010ebe33f668</td>\n",
1115
+ " <td>1.0</td>\n",
1116
+ " <td>0.0</td>\n",
1117
+ " </tr>\n",
1118
+ " <tr>\n",
1119
+ " <th>2</th>\n",
1120
+ " <td>02fa521e1838</td>\n",
1121
+ " <td>0.0</td>\n",
1122
+ " <td>1.0</td>\n",
1123
+ " </tr>\n",
1124
+ " <tr>\n",
1125
+ " <th>3</th>\n",
1126
+ " <td>040e15f562a2</td>\n",
1127
+ " <td>0.0</td>\n",
1128
+ " <td>1.0</td>\n",
1129
+ " </tr>\n",
1130
+ " <tr>\n",
1131
+ " <th>4</th>\n",
1132
+ " <td>046e85c7cc7f</td>\n",
1133
+ " <td>0.0</td>\n",
1134
+ " <td>1.0</td>\n",
1135
+ " </tr>\n",
1136
+ " </tbody>\n",
1137
+ "</table>\n",
1138
+ "</div>"
1139
+ ],
1140
+ "text/plain": [
1141
+ " Id class_0 class_1\n",
1142
+ "0 00eed32682bb 0.0 1.0\n",
1143
+ "1 010ebe33f668 1.0 0.0\n",
1144
+ "2 02fa521e1838 0.0 1.0\n",
1145
+ "3 040e15f562a2 0.0 1.0\n",
1146
+ "4 046e85c7cc7f 0.0 1.0"
1147
+ ]
1148
+ },
1149
+ "execution_count": 27,
1150
+ "metadata": {},
1151
+ "output_type": "execute_result"
1152
+ }
1153
+ ],
1154
+ "source": [
1155
+ "submission_df = pd.read_csv('submission.csv')\n",
1156
+ "submission_df"
1157
+ ]
1158
+ }
1159
+ ],
1160
+ "metadata": {
1161
+ "kernelspec": {
1162
+ "display_name": "Python 3",
1163
+ "language": "python",
1164
+ "name": "python3"
1165
+ },
1166
+ "language_info": {
1167
+ "codemirror_mode": {
1168
+ "name": "ipython",
1169
+ "version": 3
1170
+ },
1171
+ "file_extension": ".py",
1172
+ "mimetype": "text/x-python",
1173
+ "name": "python",
1174
+ "nbconvert_exporter": "python",
1175
+ "pygments_lexer": "ipython3",
1176
+ "version": "3.10.12"
1177
+ },
1178
+ "papermill": {
1179
+ "default_parameters": {},
1180
+ "duration": 1595.376081,
1181
+ "end_time": "2023-08-12T17:39:41.705163",
1182
+ "environment_variables": {},
1183
+ "exception": null,
1184
+ "input_path": "__notebook__.ipynb",
1185
+ "output_path": "__notebook__.ipynb",
1186
+ "parameters": {},
1187
+ "start_time": "2023-08-12T17:13:06.329082",
1188
+ "version": "2.4.0"
1189
+ },
1190
+ "widgets": {
1191
+ "application/vnd.jupyter.widget-state+json": {
1192
+ "state": {
1193
+ "1884fcbd64a9456dbd130120c3d4d8ba": {
1194
+ "model_module": "@jupyter-widgets/controls",
1195
+ "model_module_version": "1.5.0",
1196
+ "model_name": "DescriptionStyleModel",
1197
+ "state": {
1198
+ "_model_module": "@jupyter-widgets/controls",
1199
+ "_model_module_version": "1.5.0",
1200
+ "_model_name": "DescriptionStyleModel",
1201
+ "_view_count": null,
1202
+ "_view_module": "@jupyter-widgets/base",
1203
+ "_view_module_version": "1.2.0",
1204
+ "_view_name": "StyleView",
1205
+ "description_width": ""
1206
+ }
1207
+ },
1208
+ "23b624e1517c46c083fb19112712a8a4": {
1209
+ "model_module": "@jupyter-widgets/controls",
1210
+ "model_module_version": "1.5.0",
1211
+ "model_name": "ProgressStyleModel",
1212
+ "state": {
1213
+ "_model_module": "@jupyter-widgets/controls",
1214
+ "_model_module_version": "1.5.0",
1215
+ "_model_name": "ProgressStyleModel",
1216
+ "_view_count": null,
1217
+ "_view_module": "@jupyter-widgets/base",
1218
+ "_view_module_version": "1.2.0",
1219
+ "_view_name": "StyleView",
1220
+ "bar_color": null,
1221
+ "description_width": ""
1222
+ }
1223
+ },
1224
+ "28a546a88db1498a9bb92b6ba2b0a4c5": {
1225
+ "model_module": "@jupyter-widgets/base",
1226
+ "model_module_version": "1.2.0",
1227
+ "model_name": "LayoutModel",
1228
+ "state": {
1229
+ "_model_module": "@jupyter-widgets/base",
1230
+ "_model_module_version": "1.2.0",
1231
+ "_model_name": "LayoutModel",
1232
+ "_view_count": null,
1233
+ "_view_module": "@jupyter-widgets/base",
1234
+ "_view_module_version": "1.2.0",
1235
+ "_view_name": "LayoutView",
1236
+ "align_content": null,
1237
+ "align_items": null,
1238
+ "align_self": null,
1239
+ "border": null,
1240
+ "bottom": null,
1241
+ "display": null,
1242
+ "flex": null,
1243
+ "flex_flow": null,
1244
+ "grid_area": null,
1245
+ "grid_auto_columns": null,
1246
+ "grid_auto_flow": null,
1247
+ "grid_auto_rows": null,
1248
+ "grid_column": null,
1249
+ "grid_gap": null,
1250
+ "grid_row": null,
1251
+ "grid_template_areas": null,
1252
+ "grid_template_columns": null,
1253
+ "grid_template_rows": null,
1254
+ "height": null,
1255
+ "justify_content": null,
1256
+ "justify_items": null,
1257
+ "left": null,
1258
+ "margin": null,
1259
+ "max_height": null,
1260
+ "max_width": null,
1261
+ "min_height": null,
1262
+ "min_width": null,
1263
+ "object_fit": null,
1264
+ "object_position": null,
1265
+ "order": null,
1266
+ "overflow": null,
1267
+ "overflow_x": null,
1268
+ "overflow_y": null,
1269
+ "padding": null,
1270
+ "right": null,
1271
+ "top": null,
1272
+ "visibility": null,
1273
+ "width": null
1274
+ }
1275
+ },
1276
+ "39395a57020841dcacc7c10703882292": {
1277
+ "model_module": "@jupyter-widgets/base",
1278
+ "model_module_version": "1.2.0",
1279
+ "model_name": "LayoutModel",
1280
+ "state": {
1281
+ "_model_module": "@jupyter-widgets/base",
1282
+ "_model_module_version": "1.2.0",
1283
+ "_model_name": "LayoutModel",
1284
+ "_view_count": null,
1285
+ "_view_module": "@jupyter-widgets/base",
1286
+ "_view_module_version": "1.2.0",
1287
+ "_view_name": "LayoutView",
1288
+ "align_content": null,
1289
+ "align_items": null,
1290
+ "align_self": null,
1291
+ "border": null,
1292
+ "bottom": null,
1293
+ "display": null,
1294
+ "flex": null,
1295
+ "flex_flow": null,
1296
+ "grid_area": null,
1297
+ "grid_auto_columns": null,
1298
+ "grid_auto_flow": null,
1299
+ "grid_auto_rows": null,
1300
+ "grid_column": null,
1301
+ "grid_gap": null,
1302
+ "grid_row": null,
1303
+ "grid_template_areas": null,
1304
+ "grid_template_columns": null,
1305
+ "grid_template_rows": null,
1306
+ "height": null,
1307
+ "justify_content": null,
1308
+ "justify_items": null,
1309
+ "left": null,
1310
+ "margin": null,
1311
+ "max_height": null,
1312
+ "max_width": null,
1313
+ "min_height": null,
1314
+ "min_width": null,
1315
+ "object_fit": null,
1316
+ "object_position": null,
1317
+ "order": null,
1318
+ "overflow": null,
1319
+ "overflow_x": null,
1320
+ "overflow_y": null,
1321
+ "padding": null,
1322
+ "right": null,
1323
+ "top": null,
1324
+ "visibility": null,
1325
+ "width": null
1326
+ }
1327
+ },
1328
+ "44f9d000be8742ff88b46f6d256e30a4": {
1329
+ "model_module": "@jupyter-widgets/controls",
1330
+ "model_module_version": "1.5.0",
1331
+ "model_name": "FloatProgressModel",
1332
+ "state": {
1333
+ "_dom_classes": [],
1334
+ "_model_module": "@jupyter-widgets/controls",
1335
+ "_model_module_version": "1.5.0",
1336
+ "_model_name": "FloatProgressModel",
1337
+ "_view_count": null,
1338
+ "_view_module": "@jupyter-widgets/controls",
1339
+ "_view_module_version": "1.5.0",
1340
+ "_view_name": "ProgressView",
1341
+ "bar_style": "success",
1342
+ "description": "",
1343
+ "description_tooltip": null,
1344
+ "layout": "IPY_MODEL_c7d90f7285a742289dfd6f57af84dc49",
1345
+ "max": 5.0,
1346
+ "min": 0.0,
1347
+ "orientation": "horizontal",
1348
+ "style": "IPY_MODEL_23b624e1517c46c083fb19112712a8a4",
1349
+ "value": 5.0
1350
+ }
1351
+ },
1352
+ "646f09bebd9245c186074b3a517485f3": {
1353
+ "model_module": "@jupyter-widgets/controls",
1354
+ "model_module_version": "1.5.0",
1355
+ "model_name": "HBoxModel",
1356
+ "state": {
1357
+ "_dom_classes": [],
1358
+ "_model_module": "@jupyter-widgets/controls",
1359
+ "_model_module_version": "1.5.0",
1360
+ "_model_name": "HBoxModel",
1361
+ "_view_count": null,
1362
+ "_view_module": "@jupyter-widgets/controls",
1363
+ "_view_module_version": "1.5.0",
1364
+ "_view_name": "HBoxView",
1365
+ "box_style": "",
1366
+ "children": [
1367
+ "IPY_MODEL_b3b939acd8834411b49a03c6620ad84e",
1368
+ "IPY_MODEL_44f9d000be8742ff88b46f6d256e30a4",
1369
+ "IPY_MODEL_979b377395c8447aadcc5181219086f6"
1370
+ ],
1371
+ "layout": "IPY_MODEL_39395a57020841dcacc7c10703882292"
1372
+ }
1373
+ },
1374
+ "979b377395c8447aadcc5181219086f6": {
1375
+ "model_module": "@jupyter-widgets/controls",
1376
+ "model_module_version": "1.5.0",
1377
+ "model_name": "HTMLModel",
1378
+ "state": {
1379
+ "_dom_classes": [],
1380
+ "_model_module": "@jupyter-widgets/controls",
1381
+ "_model_module_version": "1.5.0",
1382
+ "_model_name": "HTMLModel",
1383
+ "_view_count": null,
1384
+ "_view_module": "@jupyter-widgets/controls",
1385
+ "_view_module_version": "1.5.0",
1386
+ "_view_name": "HTMLView",
1387
+ "description": "",
1388
+ "description_tooltip": null,
1389
+ "layout": "IPY_MODEL_28a546a88db1498a9bb92b6ba2b0a4c5",
1390
+ "placeholder": "​",
1391
+ "style": "IPY_MODEL_1884fcbd64a9456dbd130120c3d4d8ba",
1392
+ "value": " 5/5 [21:54&lt;00:00, 262.32s/it]"
1393
+ }
1394
+ },
1395
+ "9b088d9442c54c9690f889b21203e3e1": {
1396
+ "model_module": "@jupyter-widgets/base",
1397
+ "model_module_version": "1.2.0",
1398
+ "model_name": "LayoutModel",
1399
+ "state": {
1400
+ "_model_module": "@jupyter-widgets/base",
1401
+ "_model_module_version": "1.2.0",
1402
+ "_model_name": "LayoutModel",
1403
+ "_view_count": null,
1404
+ "_view_module": "@jupyter-widgets/base",
1405
+ "_view_module_version": "1.2.0",
1406
+ "_view_name": "LayoutView",
1407
+ "align_content": null,
1408
+ "align_items": null,
1409
+ "align_self": null,
1410
+ "border": null,
1411
+ "bottom": null,
1412
+ "display": null,
1413
+ "flex": null,
1414
+ "flex_flow": null,
1415
+ "grid_area": null,
1416
+ "grid_auto_columns": null,
1417
+ "grid_auto_flow": null,
1418
+ "grid_auto_rows": null,
1419
+ "grid_column": null,
1420
+ "grid_gap": null,
1421
+ "grid_row": null,
1422
+ "grid_template_areas": null,
1423
+ "grid_template_columns": null,
1424
+ "grid_template_rows": null,
1425
+ "height": null,
1426
+ "justify_content": null,
1427
+ "justify_items": null,
1428
+ "left": null,
1429
+ "margin": null,
1430
+ "max_height": null,
1431
+ "max_width": null,
1432
+ "min_height": null,
1433
+ "min_width": null,
1434
+ "object_fit": null,
1435
+ "object_position": null,
1436
+ "order": null,
1437
+ "overflow": null,
1438
+ "overflow_x": null,
1439
+ "overflow_y": null,
1440
+ "padding": null,
1441
+ "right": null,
1442
+ "top": null,
1443
+ "visibility": null,
1444
+ "width": null
1445
+ }
1446
+ },
1447
+ "b3b939acd8834411b49a03c6620ad84e": {
1448
+ "model_module": "@jupyter-widgets/controls",
1449
+ "model_module_version": "1.5.0",
1450
+ "model_name": "HTMLModel",
1451
+ "state": {
1452
+ "_dom_classes": [],
1453
+ "_model_module": "@jupyter-widgets/controls",
1454
+ "_model_module_version": "1.5.0",
1455
+ "_model_name": "HTMLModel",
1456
+ "_view_count": null,
1457
+ "_view_module": "@jupyter-widgets/controls",
1458
+ "_view_module_version": "1.5.0",
1459
+ "_view_name": "HTMLView",
1460
+ "description": "",
1461
+ "description_tooltip": null,
1462
+ "layout": "IPY_MODEL_9b088d9442c54c9690f889b21203e3e1",
1463
+ "placeholder": "​",
1464
+ "style": "IPY_MODEL_ca62def05e9a4b55af75433b8de84833",
1465
+ "value": "100%"
1466
+ }
1467
+ },
1468
+ "c7d90f7285a742289dfd6f57af84dc49": {
1469
+ "model_module": "@jupyter-widgets/base",
1470
+ "model_module_version": "1.2.0",
1471
+ "model_name": "LayoutModel",
1472
+ "state": {
1473
+ "_model_module": "@jupyter-widgets/base",
1474
+ "_model_module_version": "1.2.0",
1475
+ "_model_name": "LayoutModel",
1476
+ "_view_count": null,
1477
+ "_view_module": "@jupyter-widgets/base",
1478
+ "_view_module_version": "1.2.0",
1479
+ "_view_name": "LayoutView",
1480
+ "align_content": null,
1481
+ "align_items": null,
1482
+ "align_self": null,
1483
+ "border": null,
1484
+ "bottom": null,
1485
+ "display": null,
1486
+ "flex": null,
1487
+ "flex_flow": null,
1488
+ "grid_area": null,
1489
+ "grid_auto_columns": null,
1490
+ "grid_auto_flow": null,
1491
+ "grid_auto_rows": null,
1492
+ "grid_column": null,
1493
+ "grid_gap": null,
1494
+ "grid_row": null,
1495
+ "grid_template_areas": null,
1496
+ "grid_template_columns": null,
1497
+ "grid_template_rows": null,
1498
+ "height": null,
1499
+ "justify_content": null,
1500
+ "justify_items": null,
1501
+ "left": null,
1502
+ "margin": null,
1503
+ "max_height": null,
1504
+ "max_width": null,
1505
+ "min_height": null,
1506
+ "min_width": null,
1507
+ "object_fit": null,
1508
+ "object_position": null,
1509
+ "order": null,
1510
+ "overflow": null,
1511
+ "overflow_x": null,
1512
+ "overflow_y": null,
1513
+ "padding": null,
1514
+ "right": null,
1515
+ "top": null,
1516
+ "visibility": null,
1517
+ "width": null
1518
+ }
1519
+ },
1520
+ "ca62def05e9a4b55af75433b8de84833": {
1521
+ "model_module": "@jupyter-widgets/controls",
1522
+ "model_module_version": "1.5.0",
1523
+ "model_name": "DescriptionStyleModel",
1524
+ "state": {
1525
+ "_model_module": "@jupyter-widgets/controls",
1526
+ "_model_module_version": "1.5.0",
1527
+ "_model_name": "DescriptionStyleModel",
1528
+ "_view_count": null,
1529
+ "_view_module": "@jupyter-widgets/base",
1530
+ "_view_module_version": "1.2.0",
1531
+ "_view_name": "StyleView",
1532
+ "description_width": ""
1533
+ }
1534
+ }
1535
+ },
1536
+ "version_major": 2,
1537
+ "version_minor": 0
1538
+ }
1539
+ }
1540
+ },
1541
+ "nbformat": 4,
1542
+ "nbformat_minor": 5
1543
+ }
TabPFN.py/Tabpfn_classifier_ver1.ipynb ADDED
@@ -0,0 +1,1442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "2e34fca0",
7
+ "metadata": {
8
+ "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
9
+ "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
10
+ "execution": {
11
+ "iopub.execute_input": "2023-07-19T15:49:20.883704Z",
12
+ "iopub.status.busy": "2023-07-19T15:49:20.883307Z",
13
+ "iopub.status.idle": "2023-07-19T15:49:20.899082Z",
14
+ "shell.execute_reply": "2023-07-19T15:49:20.897481Z"
15
+ },
16
+ "papermill": {
17
+ "duration": 0.026307,
18
+ "end_time": "2023-07-19T15:49:20.901426",
19
+ "exception": false,
20
+ "start_time": "2023-07-19T15:49:20.875119",
21
+ "status": "completed"
22
+ },
23
+ "tags": []
24
+ },
25
+ "outputs": [
26
+ {
27
+ "name": "stdout",
28
+ "output_type": "stream",
29
+ "text": [
30
+ "/kaggle/input/tabpfn-019-whl/tabpfn-0.1.9-py3-none-any.whl\n",
31
+ "/kaggle/input/tabpfn-019-whl/prior_diff_real_checkpoint_n_0_epoch_42.cpkt\n",
32
+ "/kaggle/input/tabpfn-019-whl/prior_diff_real_checkpoint_n_0_epoch_100.cpkt\n",
33
+ "/kaggle/input/icr-identify-age-related-conditions/sample_submission.csv\n",
34
+ "/kaggle/input/icr-identify-age-related-conditions/greeks.csv\n",
35
+ "/kaggle/input/icr-identify-age-related-conditions/train.csv\n",
36
+ "/kaggle/input/icr-identify-age-related-conditions/test.csv\n"
37
+ ]
38
+ }
39
+ ],
40
+ "source": [
41
+ "# This Python 3 environment comes with many helpful analytics libraries installed\n",
42
+ "# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n",
43
+ "# For example, here's several helpful packages to load\n",
44
+ "\n",
45
+ "import numpy as np # linear algebra\n",
46
+ "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
47
+ "\n",
48
+ "# Input data files are available in the read-only \"../input/\" directory\n",
49
+ "# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n",
50
+ "\n",
51
+ "import os\n",
52
+ "for dirname, _, filenames in os.walk('/kaggle/input'):\n",
53
+ " for filename in filenames:\n",
54
+ " print(os.path.join(dirname, filename))\n",
55
+ "\n",
56
+ "# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n",
57
+ "# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": 2,
63
+ "id": "c7279ab3",
64
+ "metadata": {
65
+ "execution": {
66
+ "iopub.execute_input": "2023-07-19T15:49:20.914521Z",
67
+ "iopub.status.busy": "2023-07-19T15:49:20.914137Z",
68
+ "iopub.status.idle": "2023-07-19T15:49:52.144398Z",
69
+ "shell.execute_reply": "2023-07-19T15:49:52.143220Z"
70
+ },
71
+ "papermill": {
72
+ "duration": 31.239507,
73
+ "end_time": "2023-07-19T15:49:52.146861",
74
+ "exception": false,
75
+ "start_time": "2023-07-19T15:49:20.907354",
76
+ "status": "completed"
77
+ },
78
+ "tags": []
79
+ },
80
+ "outputs": [
81
+ {
82
+ "name": "stdout",
83
+ "output_type": "stream",
84
+ "text": [
85
+ "Processing /kaggle/input/tabpfn-019-whl/tabpfn-0.1.9-py3-none-any.whl\r\n",
86
+ "Requirement already satisfied: numpy>=1.21.2 in /opt/conda/lib/python3.10/site-packages (from tabpfn==0.1.9) (1.23.5)\r\n",
87
+ "Requirement already satisfied: pyyaml>=5.4.1 in /opt/conda/lib/python3.10/site-packages (from tabpfn==0.1.9) (6.0)\r\n",
88
+ "Requirement already satisfied: requests>=2.23.0 in /opt/conda/lib/python3.10/site-packages (from tabpfn==0.1.9) (2.31.0)\r\n",
89
+ "Requirement already satisfied: scikit-learn>=0.24.2 in /opt/conda/lib/python3.10/site-packages (from tabpfn==0.1.9) (1.2.2)\r\n",
90
+ "Requirement already satisfied: torch>=1.9.0 in /opt/conda/lib/python3.10/site-packages (from tabpfn==0.1.9) (2.0.0+cpu)\r\n",
91
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests>=2.23.0->tabpfn==0.1.9) (3.1.0)\r\n",
92
+ "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests>=2.23.0->tabpfn==0.1.9) (3.4)\r\n",
93
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests>=2.23.0->tabpfn==0.1.9) (1.26.15)\r\n",
94
+ "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests>=2.23.0->tabpfn==0.1.9) (2023.5.7)\r\n",
95
+ "Requirement already satisfied: scipy>=1.3.2 in /opt/conda/lib/python3.10/site-packages (from scikit-learn>=0.24.2->tabpfn==0.1.9) (1.11.1)\r\n",
96
+ "Requirement already satisfied: joblib>=1.1.1 in /opt/conda/lib/python3.10/site-packages (from scikit-learn>=0.24.2->tabpfn==0.1.9) (1.2.0)\r\n",
97
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.10/site-packages (from scikit-learn>=0.24.2->tabpfn==0.1.9) (3.1.0)\r\n",
98
+ "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch>=1.9.0->tabpfn==0.1.9) (3.12.2)\r\n",
99
+ "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.10/site-packages (from torch>=1.9.0->tabpfn==0.1.9) (4.6.3)\r\n",
100
+ "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch>=1.9.0->tabpfn==0.1.9) (1.12)\r\n",
101
+ "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch>=1.9.0->tabpfn==0.1.9) (3.1)\r\n",
102
+ "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch>=1.9.0->tabpfn==0.1.9) (3.1.2)\r\n",
103
+ "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch>=1.9.0->tabpfn==0.1.9) (2.1.3)\r\n",
104
+ "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch>=1.9.0->tabpfn==0.1.9) (1.3.0)\r\n",
105
+ "Installing collected packages: tabpfn\r\n",
106
+ "Successfully installed tabpfn-0.1.9\r\n"
107
+ ]
108
+ }
109
+ ],
110
+ "source": [
111
+ "!pip install /kaggle/input/tabpfn-019-whl/tabpfn-0.1.9-py3-none-any.whl"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": 3,
117
+ "id": "5dda34bb",
118
+ "metadata": {
119
+ "execution": {
120
+ "iopub.execute_input": "2023-07-19T15:49:52.165897Z",
121
+ "iopub.status.busy": "2023-07-19T15:49:52.165541Z",
122
+ "iopub.status.idle": "2023-07-19T15:49:53.511008Z",
123
+ "shell.execute_reply": "2023-07-19T15:49:53.509666Z"
124
+ },
125
+ "papermill": {
126
+ "duration": 1.356566,
127
+ "end_time": "2023-07-19T15:49:53.513451",
128
+ "exception": false,
129
+ "start_time": "2023-07-19T15:49:52.156885",
130
+ "status": "completed"
131
+ },
132
+ "tags": []
133
+ },
134
+ "outputs": [],
135
+ "source": [
136
+ "!mkdir /opt/conda/lib/python3.10/site-packages/tabpfn/models_diff\n",
137
+ "!cp /kaggle/input/tabpfn-019-whl/prior_diff_real_checkpoint_n_0_epoch_100.cpkt /opt/conda/lib/python3.10/site-packages/tabpfn/models_diff/"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": 4,
143
+ "id": "048aa160",
144
+ "metadata": {
145
+ "execution": {
146
+ "iopub.execute_input": "2023-07-19T15:49:53.527449Z",
147
+ "iopub.status.busy": "2023-07-19T15:49:53.527029Z",
148
+ "iopub.status.idle": "2023-07-19T15:49:58.980538Z",
149
+ "shell.execute_reply": "2023-07-19T15:49:58.979590Z"
150
+ },
151
+ "papermill": {
152
+ "duration": 5.463371,
153
+ "end_time": "2023-07-19T15:49:58.982884",
154
+ "exception": false,
155
+ "start_time": "2023-07-19T15:49:53.519513",
156
+ "status": "completed"
157
+ },
158
+ "tags": []
159
+ },
160
+ "outputs": [
161
+ {
162
+ "name": "stderr",
163
+ "output_type": "stream",
164
+ "text": [
165
+ "/opt/conda/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.5\n",
166
+ " warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
167
+ ]
168
+ }
169
+ ],
170
+ "source": [
171
+ "from sklearn.preprocessing import LabelEncoder,normalize\n",
172
+ "from sklearn.ensemble import GradientBoostingClassifier,RandomForestClassifier\n",
173
+ "from sklearn.metrics import accuracy_score\n",
174
+ "from sklearn.impute import SimpleImputer\n",
175
+ "import imblearn\n",
176
+ "from imblearn.over_sampling import RandomOverSampler\n",
177
+ "from imblearn.under_sampling import RandomUnderSampler\n",
178
+ "import xgboost\n",
179
+ "import inspect\n",
180
+ "from collections import defaultdict\n",
181
+ "from tabpfn import TabPFNClassifier\n",
182
+ "import warnings\n",
183
+ "warnings.filterwarnings('ignore')\n",
184
+ "from sklearn.model_selection import KFold as KF, GridSearchCV\n",
185
+ "from tqdm.notebook import tqdm\n",
186
+ "from datetime import datetime"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": 5,
192
+ "id": "c860d4a3",
193
+ "metadata": {
194
+ "execution": {
195
+ "iopub.execute_input": "2023-07-19T15:49:58.998778Z",
196
+ "iopub.status.busy": "2023-07-19T15:49:58.997666Z",
197
+ "iopub.status.idle": "2023-07-19T15:49:59.003664Z",
198
+ "shell.execute_reply": "2023-07-19T15:49:59.002318Z"
199
+ },
200
+ "papermill": {
201
+ "duration": 0.015543,
202
+ "end_time": "2023-07-19T15:49:59.006138",
203
+ "exception": false,
204
+ "start_time": "2023-07-19T15:49:58.990595",
205
+ "status": "completed"
206
+ },
207
+ "tags": []
208
+ },
209
+ "outputs": [],
210
+ "source": [
211
+ "pd.set_option('display.max_columns', None)\n",
212
+ "pd.set_option('display.max_rows', None)"
213
+ ]
214
+ },
215
+ {
216
+ "cell_type": "code",
217
+ "execution_count": 6,
218
+ "id": "f7c35d8a",
219
+ "metadata": {
220
+ "execution": {
221
+ "iopub.execute_input": "2023-07-19T15:49:59.021171Z",
222
+ "iopub.status.busy": "2023-07-19T15:49:59.020777Z",
223
+ "iopub.status.idle": "2023-07-19T15:49:59.074163Z",
224
+ "shell.execute_reply": "2023-07-19T15:49:59.073110Z"
225
+ },
226
+ "papermill": {
227
+ "duration": 0.064133,
228
+ "end_time": "2023-07-19T15:49:59.077033",
229
+ "exception": false,
230
+ "start_time": "2023-07-19T15:49:59.012900",
231
+ "status": "completed"
232
+ },
233
+ "tags": []
234
+ },
235
+ "outputs": [],
236
+ "source": [
237
+ "train_df = pd.read_csv(\"/kaggle/input/icr-identify-age-related-conditions/train.csv\")\n",
238
+ "test_df = pd.read_csv(\"/kaggle/input/icr-identify-age-related-conditions/test.csv\")\n",
239
+ "greeks_df = pd.read_csv(\"/kaggle/input/icr-identify-age-related-conditions/greeks.csv\")\n",
240
+ "sample_submission = pd.read_csv(\"/kaggle/input/icr-identify-age-related-conditions/sample_submission.csv\")"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": 7,
246
+ "id": "e6f4a8d9",
247
+ "metadata": {
248
+ "execution": {
249
+ "iopub.execute_input": "2023-07-19T15:49:59.092298Z",
250
+ "iopub.status.busy": "2023-07-19T15:49:59.091276Z",
251
+ "iopub.status.idle": "2023-07-19T15:49:59.108088Z",
252
+ "shell.execute_reply": "2023-07-19T15:49:59.106443Z"
253
+ },
254
+ "papermill": {
255
+ "duration": 0.027512,
256
+ "end_time": "2023-07-19T15:49:59.110901",
257
+ "exception": false,
258
+ "start_time": "2023-07-19T15:49:59.083389",
259
+ "status": "completed"
260
+ },
261
+ "tags": []
262
+ },
263
+ "outputs": [],
264
+ "source": [
265
+ "first_category = train_df.EJ.unique()[0]\n",
266
+ "train_df.EJ = train_df.EJ.eq(first_category).astype('int')\n",
267
+ "test_df.EJ = test_df.EJ.eq(first_category).astype('int')"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": 8,
273
+ "id": "6cc16dc5",
274
+ "metadata": {
275
+ "execution": {
276
+ "iopub.execute_input": "2023-07-19T15:49:59.126235Z",
277
+ "iopub.status.busy": "2023-07-19T15:49:59.124902Z",
278
+ "iopub.status.idle": "2023-07-19T15:49:59.131980Z",
279
+ "shell.execute_reply": "2023-07-19T15:49:59.131296Z"
280
+ },
281
+ "papermill": {
282
+ "duration": 0.016797,
283
+ "end_time": "2023-07-19T15:49:59.134213",
284
+ "exception": false,
285
+ "start_time": "2023-07-19T15:49:59.117416",
286
+ "status": "completed"
287
+ },
288
+ "tags": []
289
+ },
290
+ "outputs": [],
291
+ "source": [
292
+ "def random_under_sampler(df):\n",
293
+ " neg, pos = np.bincount(df['Class'])\n",
294
+ " one_df = df.loc[df['Class'] == 1] \n",
295
+ " zero_df = df.loc[df['Class'] == 0]\n",
296
+ " zero_df = zero_df.sample(n=pos)\n",
297
+ " undersampled_df = pd.concat([zero_df, one_df])\n",
298
+ " return undersampled_df.sample(frac = 1)"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": 9,
304
+ "id": "72c5ebbc",
305
+ "metadata": {
306
+ "execution": {
307
+ "iopub.execute_input": "2023-07-19T15:49:59.147644Z",
308
+ "iopub.status.busy": "2023-07-19T15:49:59.147321Z",
309
+ "iopub.status.idle": "2023-07-19T15:49:59.161115Z",
310
+ "shell.execute_reply": "2023-07-19T15:49:59.160423Z"
311
+ },
312
+ "papermill": {
313
+ "duration": 0.023061,
314
+ "end_time": "2023-07-19T15:49:59.163502",
315
+ "exception": false,
316
+ "start_time": "2023-07-19T15:49:59.140441",
317
+ "status": "completed"
318
+ },
319
+ "tags": []
320
+ },
321
+ "outputs": [],
322
+ "source": [
323
+ "train_good = random_under_sampler(train_df)"
324
+ ]
325
+ },
326
+ {
327
+ "cell_type": "code",
328
+ "execution_count": 10,
329
+ "id": "e005e9ad",
330
+ "metadata": {
331
+ "execution": {
332
+ "iopub.execute_input": "2023-07-19T15:49:59.178245Z",
333
+ "iopub.status.busy": "2023-07-19T15:49:59.177838Z",
334
+ "iopub.status.idle": "2023-07-19T15:49:59.185746Z",
335
+ "shell.execute_reply": "2023-07-19T15:49:59.184237Z"
336
+ },
337
+ "papermill": {
338
+ "duration": 0.018132,
339
+ "end_time": "2023-07-19T15:49:59.188463",
340
+ "exception": false,
341
+ "start_time": "2023-07-19T15:49:59.170331",
342
+ "status": "completed"
343
+ },
344
+ "tags": []
345
+ },
346
+ "outputs": [
347
+ {
348
+ "data": {
349
+ "text/plain": [
350
+ "(216, 58)"
351
+ ]
352
+ },
353
+ "execution_count": 10,
354
+ "metadata": {},
355
+ "output_type": "execute_result"
356
+ }
357
+ ],
358
+ "source": [
359
+ "train_good.shape"
360
+ ]
361
+ },
362
+ {
363
+ "cell_type": "code",
364
+ "execution_count": 11,
365
+ "id": "c4cd1123",
366
+ "metadata": {
367
+ "execution": {
368
+ "iopub.execute_input": "2023-07-19T15:49:59.204657Z",
369
+ "iopub.status.busy": "2023-07-19T15:49:59.204301Z",
370
+ "iopub.status.idle": "2023-07-19T15:49:59.215752Z",
371
+ "shell.execute_reply": "2023-07-19T15:49:59.213946Z"
372
+ },
373
+ "papermill": {
374
+ "duration": 0.023495,
375
+ "end_time": "2023-07-19T15:49:59.218315",
376
+ "exception": false,
377
+ "start_time": "2023-07-19T15:49:59.194820",
378
+ "status": "completed"
379
+ },
380
+ "tags": []
381
+ },
382
+ "outputs": [],
383
+ "source": [
384
+ "predictor_columns = [n for n in train_df.columns if n != 'Class' and n != 'Id']\n",
385
+ "x= train_df[predictor_columns]\n",
386
+ "y = train_df['Class']"
387
+ ]
388
+ },
389
+ {
390
+ "cell_type": "code",
391
+ "execution_count": 12,
392
+ "id": "bd89b899",
393
+ "metadata": {
394
+ "execution": {
395
+ "iopub.execute_input": "2023-07-19T15:49:59.232391Z",
396
+ "iopub.status.busy": "2023-07-19T15:49:59.232003Z",
397
+ "iopub.status.idle": "2023-07-19T15:49:59.236624Z",
398
+ "shell.execute_reply": "2023-07-19T15:49:59.235735Z"
399
+ },
400
+ "papermill": {
401
+ "duration": 0.013656,
402
+ "end_time": "2023-07-19T15:49:59.238300",
403
+ "exception": false,
404
+ "start_time": "2023-07-19T15:49:59.224644",
405
+ "status": "completed"
406
+ },
407
+ "tags": []
408
+ },
409
+ "outputs": [],
410
+ "source": [
411
+ "cv_outer = KF(n_splits = 10, shuffle=True, random_state=42)\n",
412
+ "cv_inner = KF(n_splits = 5, shuffle=True, random_state=42)"
413
+ ]
414
+ },
415
+ {
416
+ "cell_type": "code",
417
+ "execution_count": 13,
418
+ "id": "e9063643",
419
+ "metadata": {
420
+ "execution": {
421
+ "iopub.execute_input": "2023-07-19T15:49:59.252426Z",
422
+ "iopub.status.busy": "2023-07-19T15:49:59.251721Z",
423
+ "iopub.status.idle": "2023-07-19T15:49:59.257645Z",
424
+ "shell.execute_reply": "2023-07-19T15:49:59.256965Z"
425
+ },
426
+ "papermill": {
427
+ "duration": 0.01509,
428
+ "end_time": "2023-07-19T15:49:59.259617",
429
+ "exception": false,
430
+ "start_time": "2023-07-19T15:49:59.244527",
431
+ "status": "completed"
432
+ },
433
+ "tags": []
434
+ },
435
+ "outputs": [],
436
+ "source": [
437
+ "def balanced_log_loss(y_true, y_pred):\n",
438
+ " N_0 = np.sum(1 - y_true)\n",
439
+ " N_1 = np.sum(y_true)\n",
440
+ " w_0 = 1 / N_0\n",
441
+ " w_1 = 1 / N_1\n",
442
+ " p_1 = np.clip(y_pred, 1e-15, 1 - 1e-15)\n",
443
+ " p_0 = 1 - p_1\n",
444
+ " log_loss_0 = -np.sum((1 - y_true) * np.log(p_0))\n",
445
+ " log_loss_1 = -np.sum(y_true * np.log(p_1))\n",
446
+ " balanced_log_loss = 2*(w_0 * log_loss_0 + w_1 * log_loss_1) / (w_0 + w_1)\n",
447
+ " return balanced_log_loss/(N_0+N_1)"
448
+ ]
449
+ },
450
+ {
451
+ "cell_type": "code",
452
+ "execution_count": 14,
453
+ "id": "9b67e0f4",
454
+ "metadata": {
455
+ "execution": {
456
+ "iopub.execute_input": "2023-07-19T15:49:59.274215Z",
457
+ "iopub.status.busy": "2023-07-19T15:49:59.273585Z",
458
+ "iopub.status.idle": "2023-07-19T15:49:59.283466Z",
459
+ "shell.execute_reply": "2023-07-19T15:49:59.282130Z"
460
+ },
461
+ "papermill": {
462
+ "duration": 0.020008,
463
+ "end_time": "2023-07-19T15:49:59.285916",
464
+ "exception": false,
465
+ "start_time": "2023-07-19T15:49:59.265908",
466
+ "status": "completed"
467
+ },
468
+ "tags": []
469
+ },
470
+ "outputs": [],
471
+ "source": [
472
+ "class Ensemble():\n",
473
+ " def __init__(self):\n",
474
+ " self.imputer = SimpleImputer(missing_values=np.nan, strategy='median')\n",
475
+ " self.classifiers =[xgboost.XGBClassifier(),TabPFNClassifier(N_ensemble_configurations=64)]\n",
476
+ " \n",
477
+ " def fit(self,X,y):\n",
478
+ " y = y.values\n",
479
+ " unique_classes, y = np.unique(y, return_inverse=True)\n",
480
+ " self.classes_ = unique_classes\n",
481
+ " first_category = X.EJ.unique()[0]\n",
482
+ " X.EJ = X.EJ.eq(first_category).astype('int')\n",
483
+ " X = self.imputer.fit_transform(X)\n",
484
+ "# X = normalize(X,axis=0)\n",
485
+ " for classifier in self.classifiers:\n",
486
+ " if classifier==self.classifiers[1]:\n",
487
+ " classifier.fit(X,y,overwrite_warning =True)\n",
488
+ " else :\n",
489
+ " classifier.fit(X, y)\n",
490
+ " \n",
491
+ " def predict_proba(self, x):\n",
492
+ " x = self.imputer.transform(x)\n",
493
+ "# x = normalize(x,axis=0)\n",
494
+ " probabilities = np.stack([classifier.predict_proba(x) for classifier in self.classifiers])\n",
495
+ " averaged_probabilities = np.mean(probabilities, axis=0)\n",
496
+ " class_0_est_instances = averaged_probabilities[:, 0].sum()\n",
497
+ " others_est_instances = averaged_probabilities[:, 1:].sum()\n",
498
+ " # Weighted probabilities based on class imbalance\n",
499
+ " new_probabilities = averaged_probabilities * np.array([[1/(class_0_est_instances if i==0 else others_est_instances) for i in range(averaged_probabilities.shape[1])]])\n",
500
+ " return new_probabilities / np.sum(new_probabilities, axis=1, keepdims=1) "
501
+ ]
502
+ },
503
+ {
504
+ "cell_type": "code",
505
+ "execution_count": 15,
506
+ "id": "d168bc26",
507
+ "metadata": {
508
+ "execution": {
509
+ "iopub.execute_input": "2023-07-19T15:49:59.300477Z",
510
+ "iopub.status.busy": "2023-07-19T15:49:59.299679Z",
511
+ "iopub.status.idle": "2023-07-19T15:49:59.308853Z",
512
+ "shell.execute_reply": "2023-07-19T15:49:59.307755Z"
513
+ },
514
+ "papermill": {
515
+ "duration": 0.018791,
516
+ "end_time": "2023-07-19T15:49:59.311141",
517
+ "exception": false,
518
+ "start_time": "2023-07-19T15:49:59.292350",
519
+ "status": "completed"
520
+ },
521
+ "tags": []
522
+ },
523
+ "outputs": [],
524
+ "source": [
525
+ "def training(model, x,y,y_meta):\n",
526
+ " outer_results = list()\n",
527
+ " best_loss = np.inf\n",
528
+ " split = 0\n",
529
+ " splits = 5\n",
530
+ " for train_idx,val_idx in tqdm(cv_inner.split(x), total = splits):\n",
531
+ " split+=1\n",
532
+ " x_train, x_val = x.iloc[train_idx],x.iloc[val_idx]\n",
533
+ " y_train, y_val = y_meta.iloc[train_idx], y.iloc[val_idx]\n",
534
+ " \n",
535
+ " model.fit(x_train, y_train)\n",
536
+ " y_pred = model.predict_proba(x_val)\n",
537
+ " probabilities = np.concatenate((y_pred[:,:1], np.sum(y_pred[:,1:], 1, keepdims=True)), axis=1)\n",
538
+ " p0 = probabilities[:,:1]\n",
539
+ " p0[p0 > 0.86] = 1\n",
540
+ " p0[p0 < 0.14] = 0\n",
541
+ " y_p = np.empty((y_pred.shape[0],))\n",
542
+ " for i in range(y_pred.shape[0]):\n",
543
+ " if p0[i]>=0.5:\n",
544
+ " y_p[i]= False\n",
545
+ " else :\n",
546
+ " y_p[i]=True\n",
547
+ " y_p = y_p.astype(int)\n",
548
+ " loss = balanced_log_loss(y_val,y_p)\n",
549
+ "\n",
550
+ " if loss<best_loss:\n",
551
+ " best_model = model\n",
552
+ " best_loss = loss\n",
553
+ " print('best_model_saved')\n",
554
+ " outer_results.append(loss)\n",
555
+ " print('>val_loss=',loss, 'split =',split)\n",
556
+ " print('LOSS:', np.mean(outer_results))\n",
557
+ " return best_model"
558
+ ]
559
+ },
560
+ {
561
+ "cell_type": "code",
562
+ "execution_count": 16,
563
+ "id": "079c1769",
564
+ "metadata": {
565
+ "execution": {
566
+ "iopub.execute_input": "2023-07-19T15:49:59.325185Z",
567
+ "iopub.status.busy": "2023-07-19T15:49:59.324666Z",
568
+ "iopub.status.idle": "2023-07-19T15:49:59.342172Z",
569
+ "shell.execute_reply": "2023-07-19T15:49:59.341421Z"
570
+ },
571
+ "papermill": {
572
+ "duration": 0.027008,
573
+ "end_time": "2023-07-19T15:49:59.344419",
574
+ "exception": false,
575
+ "start_time": "2023-07-19T15:49:59.317411",
576
+ "status": "completed"
577
+ },
578
+ "tags": []
579
+ },
580
+ "outputs": [],
581
+ "source": [
582
+ "times = greeks_df.Epsilon.copy()\n",
583
+ "times[greeks_df.Epsilon != 'Unknown'] = greeks_df.Epsilon[greeks_df.Epsilon != 'Unknown'].map(lambda x: datetime.strptime(x,'%m/%d/%Y').toordinal())\n",
584
+ "times[greeks_df.Epsilon == 'Unknown'] = np.nan"
585
+ ]
586
+ },
587
+ {
588
+ "cell_type": "code",
589
+ "execution_count": 17,
590
+ "id": "06e0bc9e",
591
+ "metadata": {
592
+ "execution": {
593
+ "iopub.execute_input": "2023-07-19T15:49:59.358648Z",
594
+ "iopub.status.busy": "2023-07-19T15:49:59.358118Z",
595
+ "iopub.status.idle": "2023-07-19T15:49:59.369274Z",
596
+ "shell.execute_reply": "2023-07-19T15:49:59.368592Z"
597
+ },
598
+ "papermill": {
599
+ "duration": 0.020368,
600
+ "end_time": "2023-07-19T15:49:59.371199",
601
+ "exception": false,
602
+ "start_time": "2023-07-19T15:49:59.350831",
603
+ "status": "completed"
604
+ },
605
+ "tags": []
606
+ },
607
+ "outputs": [],
608
+ "source": [
609
+ "train_pred_and_time = pd.concat((train_df, times), axis=1)\n",
610
+ "test_predictors = test_df[predictor_columns]\n",
611
+ "first_category = test_predictors.EJ.unique()[0]\n",
612
+ "test_predictors.EJ = test_predictors.EJ.eq(first_category).astype('int')\n",
613
+ "test_pred_and_time = np.concatenate((test_predictors, np.zeros((len(test_predictors), 1)) + train_pred_and_time.Epsilon.max() + 1), axis=1)"
614
+ ]
615
+ },
616
+ {
617
+ "cell_type": "code",
618
+ "execution_count": 18,
619
+ "id": "f2cfb985",
620
+ "metadata": {
621
+ "execution": {
622
+ "iopub.execute_input": "2023-07-19T15:49:59.385041Z",
623
+ "iopub.status.busy": "2023-07-19T15:49:59.384583Z",
624
+ "iopub.status.idle": "2023-07-19T15:49:59.412622Z",
625
+ "shell.execute_reply": "2023-07-19T15:49:59.411592Z"
626
+ },
627
+ "papermill": {
628
+ "duration": 0.037351,
629
+ "end_time": "2023-07-19T15:49:59.414704",
630
+ "exception": false,
631
+ "start_time": "2023-07-19T15:49:59.377353",
632
+ "status": "completed"
633
+ },
634
+ "tags": []
635
+ },
636
+ "outputs": [
637
+ {
638
+ "name": "stdout",
639
+ "output_type": "stream",
640
+ "text": [
641
+ "Original dataset shape\n",
642
+ "A 509\n",
643
+ "B 61\n",
644
+ "G 29\n",
645
+ "D 18\n",
646
+ "Name: Alpha, dtype: int64\n",
647
+ "Resample dataset shape\n",
648
+ "B 509\n",
649
+ "A 509\n",
650
+ "D 509\n",
651
+ "G 509\n",
652
+ "Name: Alpha, dtype: int64\n"
653
+ ]
654
+ }
655
+ ],
656
+ "source": [
657
+ "ros = RandomOverSampler(random_state=42)\n",
658
+ "\n",
659
+ "train_ros, y_ros = ros.fit_resample(train_pred_and_time, greeks_df.Alpha)\n",
660
+ "print('Original dataset shape')\n",
661
+ "print(greeks_df.Alpha.value_counts())\n",
662
+ "print('Resample dataset shape')\n",
663
+ "print( y_ros.value_counts())"
664
+ ]
665
+ },
666
+ {
667
+ "cell_type": "code",
668
+ "execution_count": 19,
669
+ "id": "011c35ed",
670
+ "metadata": {
671
+ "execution": {
672
+ "iopub.execute_input": "2023-07-19T15:49:59.429884Z",
673
+ "iopub.status.busy": "2023-07-19T15:49:59.429313Z",
674
+ "iopub.status.idle": "2023-07-19T15:49:59.435592Z",
675
+ "shell.execute_reply": "2023-07-19T15:49:59.434689Z"
676
+ },
677
+ "papermill": {
678
+ "duration": 0.016651,
679
+ "end_time": "2023-07-19T15:49:59.437651",
680
+ "exception": false,
681
+ "start_time": "2023-07-19T15:49:59.421000",
682
+ "status": "completed"
683
+ },
684
+ "tags": []
685
+ },
686
+ "outputs": [],
687
+ "source": [
688
+ "x_ros = train_ros.drop(['Class', 'Id'],axis=1)\n",
689
+ "y_ = train_ros.Class"
690
+ ]
691
+ },
692
+ {
693
+ "cell_type": "code",
694
+ "execution_count": 20,
695
+ "id": "d25f1095",
696
+ "metadata": {
697
+ "execution": {
698
+ "iopub.execute_input": "2023-07-19T15:49:59.451738Z",
699
+ "iopub.status.busy": "2023-07-19T15:49:59.451218Z",
700
+ "iopub.status.idle": "2023-07-19T15:49:59.867640Z",
701
+ "shell.execute_reply": "2023-07-19T15:49:59.866649Z"
702
+ },
703
+ "papermill": {
704
+ "duration": 0.425633,
705
+ "end_time": "2023-07-19T15:49:59.869598",
706
+ "exception": false,
707
+ "start_time": "2023-07-19T15:49:59.443965",
708
+ "status": "completed"
709
+ },
710
+ "tags": []
711
+ },
712
+ "outputs": [
713
+ {
714
+ "name": "stdout",
715
+ "output_type": "stream",
716
+ "text": [
717
+ "Loading model that can be used for inference only\n",
718
+ "Using a Transformer with 25.82 M parameters\n"
719
+ ]
720
+ }
721
+ ],
722
+ "source": [
723
+ "yt = Ensemble()"
724
+ ]
725
+ },
726
+ {
727
+ "cell_type": "code",
728
+ "execution_count": 21,
729
+ "id": "edb43c1a",
730
+ "metadata": {
731
+ "execution": {
732
+ "iopub.execute_input": "2023-07-19T15:49:59.884325Z",
733
+ "iopub.status.busy": "2023-07-19T15:49:59.883762Z",
734
+ "iopub.status.idle": "2023-07-19T16:01:07.842831Z",
735
+ "shell.execute_reply": "2023-07-19T16:01:07.841617Z"
736
+ },
737
+ "papermill": {
738
+ "duration": 667.976046,
739
+ "end_time": "2023-07-19T16:01:07.852065",
740
+ "exception": false,
741
+ "start_time": "2023-07-19T15:49:59.876019",
742
+ "status": "completed"
743
+ },
744
+ "tags": []
745
+ },
746
+ "outputs": [
747
+ {
748
+ "data": {
749
+ "application/vnd.jupyter.widget-view+json": {
750
+ "model_id": "bd3cb9f78122483eb70d16ca6c7b8962",
751
+ "version_major": 2,
752
+ "version_minor": 0
753
+ },
754
+ "text/plain": [
755
+ " 0%| | 0/5 [00:00<?, ?it/s]"
756
+ ]
757
+ },
758
+ "metadata": {},
759
+ "output_type": "display_data"
760
+ },
761
+ {
762
+ "name": "stdout",
763
+ "output_type": "stream",
764
+ "text": [
765
+ "best_model_saved\n",
766
+ ">val_loss= 0.12283393999583053 split = 1\n"
767
+ ]
768
+ },
769
+ {
770
+ "name": "stderr",
771
+ "output_type": "stream",
772
+ "text": [
773
+ "/tmp/ipykernel_20/2226499128.py:11: SettingWithCopyWarning: \n",
774
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
775
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
776
+ "\n",
777
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
778
+ " X.EJ = X.EJ.eq(first_category).astype('int')\n"
779
+ ]
780
+ },
781
+ {
782
+ "name": "stdout",
783
+ "output_type": "stream",
784
+ "text": [
785
+ "best_model_saved\n",
786
+ ">val_loss= 7.882664572210757e-16 split = 2\n"
787
+ ]
788
+ },
789
+ {
790
+ "name": "stderr",
791
+ "output_type": "stream",
792
+ "text": [
793
+ "/tmp/ipykernel_20/2226499128.py:11: SettingWithCopyWarning: \n",
794
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
795
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
796
+ "\n",
797
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
798
+ " X.EJ = X.EJ.eq(first_category).astype('int')\n"
799
+ ]
800
+ },
801
+ {
802
+ "name": "stdout",
803
+ "output_type": "stream",
804
+ "text": [
805
+ ">val_loss= 7.927542919637485e-16 split = 3\n"
806
+ ]
807
+ },
808
+ {
809
+ "name": "stderr",
810
+ "output_type": "stream",
811
+ "text": [
812
+ "/tmp/ipykernel_20/2226499128.py:11: SettingWithCopyWarning: \n",
813
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
814
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
815
+ "\n",
816
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
817
+ " X.EJ = X.EJ.eq(first_category).astype('int')\n"
818
+ ]
819
+ },
820
+ {
821
+ "name": "stdout",
822
+ "output_type": "stream",
823
+ "text": [
824
+ "best_model_saved\n",
825
+ ">val_loss= 6.883759419809394e-16 split = 4\n"
826
+ ]
827
+ },
828
+ {
829
+ "name": "stderr",
830
+ "output_type": "stream",
831
+ "text": [
832
+ "/tmp/ipykernel_20/2226499128.py:11: SettingWithCopyWarning: \n",
833
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
834
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
835
+ "\n",
836
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
837
+ " X.EJ = X.EJ.eq(first_category).astype('int')\n"
838
+ ]
839
+ },
840
+ {
841
+ "name": "stdout",
842
+ "output_type": "stream",
843
+ "text": [
844
+ ">val_loss= 0.13386381920254847 split = 5\n",
845
+ "LOSS: 0.051339551839676256\n"
846
+ ]
847
+ }
848
+ ],
849
+ "source": [
850
+ "m = training(yt,x_ros,y_,y_ros)"
851
+ ]
852
+ },
853
+ {
854
+ "cell_type": "code",
855
+ "execution_count": 22,
856
+ "id": "5a469b9e",
857
+ "metadata": {
858
+ "execution": {
859
+ "iopub.execute_input": "2023-07-19T16:01:07.867409Z",
860
+ "iopub.status.busy": "2023-07-19T16:01:07.867041Z",
861
+ "iopub.status.idle": "2023-07-19T16:01:07.875651Z",
862
+ "shell.execute_reply": "2023-07-19T16:01:07.874672Z"
863
+ },
864
+ "papermill": {
865
+ "duration": 0.019324,
866
+ "end_time": "2023-07-19T16:01:07.878246",
867
+ "exception": false,
868
+ "start_time": "2023-07-19T16:01:07.858922",
869
+ "status": "completed"
870
+ },
871
+ "tags": []
872
+ },
873
+ "outputs": [
874
+ {
875
+ "data": {
876
+ "text/plain": [
877
+ "1 0.75\n",
878
+ "0 0.25\n",
879
+ "Name: Class, dtype: float64"
880
+ ]
881
+ },
882
+ "execution_count": 22,
883
+ "metadata": {},
884
+ "output_type": "execute_result"
885
+ }
886
+ ],
887
+ "source": [
888
+ "y_.value_counts()/y_.shape[0]"
889
+ ]
890
+ },
891
+ {
892
+ "cell_type": "code",
893
+ "execution_count": 23,
894
+ "id": "ae189a8b",
895
+ "metadata": {
896
+ "execution": {
897
+ "iopub.execute_input": "2023-07-19T16:01:07.894782Z",
898
+ "iopub.status.busy": "2023-07-19T16:01:07.894387Z",
899
+ "iopub.status.idle": "2023-07-19T16:02:58.756981Z",
900
+ "shell.execute_reply": "2023-07-19T16:02:58.756252Z"
901
+ },
902
+ "papermill": {
903
+ "duration": 110.872949,
904
+ "end_time": "2023-07-19T16:02:58.758958",
905
+ "exception": false,
906
+ "start_time": "2023-07-19T16:01:07.886009",
907
+ "status": "completed"
908
+ },
909
+ "tags": []
910
+ },
911
+ "outputs": [
912
+ {
913
+ "name": "stderr",
914
+ "output_type": "stream",
915
+ "text": [
916
+ "/opt/conda/lib/python3.10/site-packages/sklearn/base.py:439: UserWarning: X does not have valid feature names, but SimpleImputer was fitted with feature names\n",
917
+ " warnings.warn(\n"
918
+ ]
919
+ }
920
+ ],
921
+ "source": [
922
+ "y_pred = m.predict_proba(test_pred_and_time)\n",
923
+ "probabilities = np.concatenate((y_pred[:,:1], np.sum(y_pred[:,1:], 1, keepdims=True)), axis=1)\n",
924
+ "p0 = probabilities[:,:1]\n",
925
+ "p0[p0 > 0.70] = 1 \n",
926
+ "p0[p0 < 0.26] = 0"
927
+ ]
928
+ },
929
+ {
930
+ "cell_type": "code",
931
+ "execution_count": 24,
932
+ "id": "351548b7",
933
+ "metadata": {
934
+ "execution": {
935
+ "iopub.execute_input": "2023-07-19T16:02:58.775878Z",
936
+ "iopub.status.busy": "2023-07-19T16:02:58.774253Z",
937
+ "iopub.status.idle": "2023-07-19T16:02:58.785182Z",
938
+ "shell.execute_reply": "2023-07-19T16:02:58.784481Z"
939
+ },
940
+ "papermill": {
941
+ "duration": 0.021177,
942
+ "end_time": "2023-07-19T16:02:58.787220",
943
+ "exception": false,
944
+ "start_time": "2023-07-19T16:02:58.766043",
945
+ "status": "completed"
946
+ },
947
+ "tags": []
948
+ },
949
+ "outputs": [],
950
+ "source": [
951
+ "submission = pd.DataFrame(test_df[\"Id\"], columns=[\"Id\"])\n",
952
+ "submission[\"class_0\"] = p0\n",
953
+ "submission[\"class_1\"] = 1 - p0\n",
954
+ "submission.to_csv('submission.csv', index=False)"
955
+ ]
956
+ },
957
+ {
958
+ "cell_type": "code",
959
+ "execution_count": 25,
960
+ "id": "37c7c730",
961
+ "metadata": {
962
+ "execution": {
963
+ "iopub.execute_input": "2023-07-19T16:02:58.803532Z",
964
+ "iopub.status.busy": "2023-07-19T16:02:58.802592Z",
965
+ "iopub.status.idle": "2023-07-19T16:02:58.821606Z",
966
+ "shell.execute_reply": "2023-07-19T16:02:58.820310Z"
967
+ },
968
+ "papermill": {
969
+ "duration": 0.029582,
970
+ "end_time": "2023-07-19T16:02:58.824105",
971
+ "exception": false,
972
+ "start_time": "2023-07-19T16:02:58.794523",
973
+ "status": "completed"
974
+ },
975
+ "tags": []
976
+ },
977
+ "outputs": [
978
+ {
979
+ "data": {
980
+ "text/html": [
981
+ "<div>\n",
982
+ "<style scoped>\n",
983
+ " .dataframe tbody tr th:only-of-type {\n",
984
+ " vertical-align: middle;\n",
985
+ " }\n",
986
+ "\n",
987
+ " .dataframe tbody tr th {\n",
988
+ " vertical-align: top;\n",
989
+ " }\n",
990
+ "\n",
991
+ " .dataframe thead th {\n",
992
+ " text-align: right;\n",
993
+ " }\n",
994
+ "</style>\n",
995
+ "<table border=\"1\" class=\"dataframe\">\n",
996
+ " <thead>\n",
997
+ " <tr style=\"text-align: right;\">\n",
998
+ " <th></th>\n",
999
+ " <th>Id</th>\n",
1000
+ " <th>class_0</th>\n",
1001
+ " <th>class_1</th>\n",
1002
+ " </tr>\n",
1003
+ " </thead>\n",
1004
+ " <tbody>\n",
1005
+ " <tr>\n",
1006
+ " <th>0</th>\n",
1007
+ " <td>00eed32682bb</td>\n",
1008
+ " <td>0.5</td>\n",
1009
+ " <td>0.5</td>\n",
1010
+ " </tr>\n",
1011
+ " <tr>\n",
1012
+ " <th>1</th>\n",
1013
+ " <td>010ebe33f668</td>\n",
1014
+ " <td>0.5</td>\n",
1015
+ " <td>0.5</td>\n",
1016
+ " </tr>\n",
1017
+ " <tr>\n",
1018
+ " <th>2</th>\n",
1019
+ " <td>02fa521e1838</td>\n",
1020
+ " <td>0.5</td>\n",
1021
+ " <td>0.5</td>\n",
1022
+ " </tr>\n",
1023
+ " <tr>\n",
1024
+ " <th>3</th>\n",
1025
+ " <td>040e15f562a2</td>\n",
1026
+ " <td>0.5</td>\n",
1027
+ " <td>0.5</td>\n",
1028
+ " </tr>\n",
1029
+ " <tr>\n",
1030
+ " <th>4</th>\n",
1031
+ " <td>046e85c7cc7f</td>\n",
1032
+ " <td>0.5</td>\n",
1033
+ " <td>0.5</td>\n",
1034
+ " </tr>\n",
1035
+ " </tbody>\n",
1036
+ "</table>\n",
1037
+ "</div>"
1038
+ ],
1039
+ "text/plain": [
1040
+ " Id class_0 class_1\n",
1041
+ "0 00eed32682bb 0.5 0.5\n",
1042
+ "1 010ebe33f668 0.5 0.5\n",
1043
+ "2 02fa521e1838 0.5 0.5\n",
1044
+ "3 040e15f562a2 0.5 0.5\n",
1045
+ "4 046e85c7cc7f 0.5 0.5"
1046
+ ]
1047
+ },
1048
+ "execution_count": 25,
1049
+ "metadata": {},
1050
+ "output_type": "execute_result"
1051
+ }
1052
+ ],
1053
+ "source": [
1054
+ "submission_df = pd.read_csv('submission.csv')\n",
1055
+ "submission_df"
1056
+ ]
1057
+ }
1058
+ ],
1059
+ "metadata": {
1060
+ "kernelspec": {
1061
+ "display_name": "Python 3",
1062
+ "language": "python",
1063
+ "name": "python3"
1064
+ },
1065
+ "language_info": {
1066
+ "codemirror_mode": {
1067
+ "name": "ipython",
1068
+ "version": 3
1069
+ },
1070
+ "file_extension": ".py",
1071
+ "mimetype": "text/x-python",
1072
+ "name": "python",
1073
+ "nbconvert_exporter": "python",
1074
+ "pygments_lexer": "ipython3",
1075
+ "version": "3.10.12"
1076
+ },
1077
+ "papermill": {
1078
+ "default_parameters": {},
1079
+ "duration": 828.301857,
1080
+ "end_time": "2023-07-19T16:03:00.358486",
1081
+ "environment_variables": {},
1082
+ "exception": null,
1083
+ "input_path": "__notebook__.ipynb",
1084
+ "output_path": "__notebook__.ipynb",
1085
+ "parameters": {},
1086
+ "start_time": "2023-07-19T15:49:12.056629",
1087
+ "version": "2.4.0"
1088
+ },
1089
+ "widgets": {
1090
+ "application/vnd.jupyter.widget-state+json": {
1091
+ "state": {
1092
+ "34c26c71e174412d8c6bce15f6cf55ab": {
1093
+ "model_module": "@jupyter-widgets/controls",
1094
+ "model_module_version": "1.5.0",
1095
+ "model_name": "ProgressStyleModel",
1096
+ "state": {
1097
+ "_model_module": "@jupyter-widgets/controls",
1098
+ "_model_module_version": "1.5.0",
1099
+ "_model_name": "ProgressStyleModel",
1100
+ "_view_count": null,
1101
+ "_view_module": "@jupyter-widgets/base",
1102
+ "_view_module_version": "1.2.0",
1103
+ "_view_name": "StyleView",
1104
+ "bar_color": null,
1105
+ "description_width": ""
1106
+ }
1107
+ },
1108
+ "4bdcfdf8ea294ffa851124a60dd797b5": {
1109
+ "model_module": "@jupyter-widgets/base",
1110
+ "model_module_version": "1.2.0",
1111
+ "model_name": "LayoutModel",
1112
+ "state": {
1113
+ "_model_module": "@jupyter-widgets/base",
1114
+ "_model_module_version": "1.2.0",
1115
+ "_model_name": "LayoutModel",
1116
+ "_view_count": null,
1117
+ "_view_module": "@jupyter-widgets/base",
1118
+ "_view_module_version": "1.2.0",
1119
+ "_view_name": "LayoutView",
1120
+ "align_content": null,
1121
+ "align_items": null,
1122
+ "align_self": null,
1123
+ "border": null,
1124
+ "bottom": null,
1125
+ "display": null,
1126
+ "flex": null,
1127
+ "flex_flow": null,
1128
+ "grid_area": null,
1129
+ "grid_auto_columns": null,
1130
+ "grid_auto_flow": null,
1131
+ "grid_auto_rows": null,
1132
+ "grid_column": null,
1133
+ "grid_gap": null,
1134
+ "grid_row": null,
1135
+ "grid_template_areas": null,
1136
+ "grid_template_columns": null,
1137
+ "grid_template_rows": null,
1138
+ "height": null,
1139
+ "justify_content": null,
1140
+ "justify_items": null,
1141
+ "left": null,
1142
+ "margin": null,
1143
+ "max_height": null,
1144
+ "max_width": null,
1145
+ "min_height": null,
1146
+ "min_width": null,
1147
+ "object_fit": null,
1148
+ "object_position": null,
1149
+ "order": null,
1150
+ "overflow": null,
1151
+ "overflow_x": null,
1152
+ "overflow_y": null,
1153
+ "padding": null,
1154
+ "right": null,
1155
+ "top": null,
1156
+ "visibility": null,
1157
+ "width": null
1158
+ }
1159
+ },
1160
+ "86be6094b7ea4c018546c7a08ac21c32": {
1161
+ "model_module": "@jupyter-widgets/controls",
1162
+ "model_module_version": "1.5.0",
1163
+ "model_name": "DescriptionStyleModel",
1164
+ "state": {
1165
+ "_model_module": "@jupyter-widgets/controls",
1166
+ "_model_module_version": "1.5.0",
1167
+ "_model_name": "DescriptionStyleModel",
1168
+ "_view_count": null,
1169
+ "_view_module": "@jupyter-widgets/base",
1170
+ "_view_module_version": "1.2.0",
1171
+ "_view_name": "StyleView",
1172
+ "description_width": ""
1173
+ }
1174
+ },
1175
+ "92a276fc64f04f4a9220c8ecc22115b2": {
1176
+ "model_module": "@jupyter-widgets/controls",
1177
+ "model_module_version": "1.5.0",
1178
+ "model_name": "FloatProgressModel",
1179
+ "state": {
1180
+ "_dom_classes": [],
1181
+ "_model_module": "@jupyter-widgets/controls",
1182
+ "_model_module_version": "1.5.0",
1183
+ "_model_name": "FloatProgressModel",
1184
+ "_view_count": null,
1185
+ "_view_module": "@jupyter-widgets/controls",
1186
+ "_view_module_version": "1.5.0",
1187
+ "_view_name": "ProgressView",
1188
+ "bar_style": "success",
1189
+ "description": "",
1190
+ "description_tooltip": null,
1191
+ "layout": "IPY_MODEL_9ca2f9f0ca2f4c368d1589f3daef97b6",
1192
+ "max": 5.0,
1193
+ "min": 0.0,
1194
+ "orientation": "horizontal",
1195
+ "style": "IPY_MODEL_34c26c71e174412d8c6bce15f6cf55ab",
1196
+ "value": 5.0
1197
+ }
1198
+ },
1199
+ "9ca2f9f0ca2f4c368d1589f3daef97b6": {
1200
+ "model_module": "@jupyter-widgets/base",
1201
+ "model_module_version": "1.2.0",
1202
+ "model_name": "LayoutModel",
1203
+ "state": {
1204
+ "_model_module": "@jupyter-widgets/base",
1205
+ "_model_module_version": "1.2.0",
1206
+ "_model_name": "LayoutModel",
1207
+ "_view_count": null,
1208
+ "_view_module": "@jupyter-widgets/base",
1209
+ "_view_module_version": "1.2.0",
1210
+ "_view_name": "LayoutView",
1211
+ "align_content": null,
1212
+ "align_items": null,
1213
+ "align_self": null,
1214
+ "border": null,
1215
+ "bottom": null,
1216
+ "display": null,
1217
+ "flex": null,
1218
+ "flex_flow": null,
1219
+ "grid_area": null,
1220
+ "grid_auto_columns": null,
1221
+ "grid_auto_flow": null,
1222
+ "grid_auto_rows": null,
1223
+ "grid_column": null,
1224
+ "grid_gap": null,
1225
+ "grid_row": null,
1226
+ "grid_template_areas": null,
1227
+ "grid_template_columns": null,
1228
+ "grid_template_rows": null,
1229
+ "height": null,
1230
+ "justify_content": null,
1231
+ "justify_items": null,
1232
+ "left": null,
1233
+ "margin": null,
1234
+ "max_height": null,
1235
+ "max_width": null,
1236
+ "min_height": null,
1237
+ "min_width": null,
1238
+ "object_fit": null,
1239
+ "object_position": null,
1240
+ "order": null,
1241
+ "overflow": null,
1242
+ "overflow_x": null,
1243
+ "overflow_y": null,
1244
+ "padding": null,
1245
+ "right": null,
1246
+ "top": null,
1247
+ "visibility": null,
1248
+ "width": null
1249
+ }
1250
+ },
1251
+ "9d0450740ef245988383b5b43528cb3d": {
1252
+ "model_module": "@jupyter-widgets/controls",
1253
+ "model_module_version": "1.5.0",
1254
+ "model_name": "DescriptionStyleModel",
1255
+ "state": {
1256
+ "_model_module": "@jupyter-widgets/controls",
1257
+ "_model_module_version": "1.5.0",
1258
+ "_model_name": "DescriptionStyleModel",
1259
+ "_view_count": null,
1260
+ "_view_module": "@jupyter-widgets/base",
1261
+ "_view_module_version": "1.2.0",
1262
+ "_view_name": "StyleView",
1263
+ "description_width": ""
1264
+ }
1265
+ },
1266
+ "a6254b23a9df47ec88478882c76e34a1": {
1267
+ "model_module": "@jupyter-widgets/controls",
1268
+ "model_module_version": "1.5.0",
1269
+ "model_name": "HTMLModel",
1270
+ "state": {
1271
+ "_dom_classes": [],
1272
+ "_model_module": "@jupyter-widgets/controls",
1273
+ "_model_module_version": "1.5.0",
1274
+ "_model_name": "HTMLModel",
1275
+ "_view_count": null,
1276
+ "_view_module": "@jupyter-widgets/controls",
1277
+ "_view_module_version": "1.5.0",
1278
+ "_view_name": "HTMLView",
1279
+ "description": "",
1280
+ "description_tooltip": null,
1281
+ "layout": "IPY_MODEL_4bdcfdf8ea294ffa851124a60dd797b5",
1282
+ "placeholder": "​",
1283
+ "style": "IPY_MODEL_86be6094b7ea4c018546c7a08ac21c32",
1284
+ "value": "100%"
1285
+ }
1286
+ },
1287
+ "b38b9d8d5d294e01bef5692bb9f9a086": {
1288
+ "model_module": "@jupyter-widgets/controls",
1289
+ "model_module_version": "1.5.0",
1290
+ "model_name": "HTMLModel",
1291
+ "state": {
1292
+ "_dom_classes": [],
1293
+ "_model_module": "@jupyter-widgets/controls",
1294
+ "_model_module_version": "1.5.0",
1295
+ "_model_name": "HTMLModel",
1296
+ "_view_count": null,
1297
+ "_view_module": "@jupyter-widgets/controls",
1298
+ "_view_module_version": "1.5.0",
1299
+ "_view_name": "HTMLView",
1300
+ "description": "",
1301
+ "description_tooltip": null,
1302
+ "layout": "IPY_MODEL_bf81d5f910ca475798b1bc946f8475b5",
1303
+ "placeholder": "​",
1304
+ "style": "IPY_MODEL_9d0450740ef245988383b5b43528cb3d",
1305
+ "value": " 5/5 [11:07&lt;00:00, 133.98s/it]"
1306
+ }
1307
+ },
1308
+ "b9ccf904ba1f46aaae2cc1c094d45b0b": {
1309
+ "model_module": "@jupyter-widgets/base",
1310
+ "model_module_version": "1.2.0",
1311
+ "model_name": "LayoutModel",
1312
+ "state": {
1313
+ "_model_module": "@jupyter-widgets/base",
1314
+ "_model_module_version": "1.2.0",
1315
+ "_model_name": "LayoutModel",
1316
+ "_view_count": null,
1317
+ "_view_module": "@jupyter-widgets/base",
1318
+ "_view_module_version": "1.2.0",
1319
+ "_view_name": "LayoutView",
1320
+ "align_content": null,
1321
+ "align_items": null,
1322
+ "align_self": null,
1323
+ "border": null,
1324
+ "bottom": null,
1325
+ "display": null,
1326
+ "flex": null,
1327
+ "flex_flow": null,
1328
+ "grid_area": null,
1329
+ "grid_auto_columns": null,
1330
+ "grid_auto_flow": null,
1331
+ "grid_auto_rows": null,
1332
+ "grid_column": null,
1333
+ "grid_gap": null,
1334
+ "grid_row": null,
1335
+ "grid_template_areas": null,
1336
+ "grid_template_columns": null,
1337
+ "grid_template_rows": null,
1338
+ "height": null,
1339
+ "justify_content": null,
1340
+ "justify_items": null,
1341
+ "left": null,
1342
+ "margin": null,
1343
+ "max_height": null,
1344
+ "max_width": null,
1345
+ "min_height": null,
1346
+ "min_width": null,
1347
+ "object_fit": null,
1348
+ "object_position": null,
1349
+ "order": null,
1350
+ "overflow": null,
1351
+ "overflow_x": null,
1352
+ "overflow_y": null,
1353
+ "padding": null,
1354
+ "right": null,
1355
+ "top": null,
1356
+ "visibility": null,
1357
+ "width": null
1358
+ }
1359
+ },
1360
+ "bd3cb9f78122483eb70d16ca6c7b8962": {
1361
+ "model_module": "@jupyter-widgets/controls",
1362
+ "model_module_version": "1.5.0",
1363
+ "model_name": "HBoxModel",
1364
+ "state": {
1365
+ "_dom_classes": [],
1366
+ "_model_module": "@jupyter-widgets/controls",
1367
+ "_model_module_version": "1.5.0",
1368
+ "_model_name": "HBoxModel",
1369
+ "_view_count": null,
1370
+ "_view_module": "@jupyter-widgets/controls",
1371
+ "_view_module_version": "1.5.0",
1372
+ "_view_name": "HBoxView",
1373
+ "box_style": "",
1374
+ "children": [
1375
+ "IPY_MODEL_a6254b23a9df47ec88478882c76e34a1",
1376
+ "IPY_MODEL_92a276fc64f04f4a9220c8ecc22115b2",
1377
+ "IPY_MODEL_b38b9d8d5d294e01bef5692bb9f9a086"
1378
+ ],
1379
+ "layout": "IPY_MODEL_b9ccf904ba1f46aaae2cc1c094d45b0b"
1380
+ }
1381
+ },
1382
+ "bf81d5f910ca475798b1bc946f8475b5": {
1383
+ "model_module": "@jupyter-widgets/base",
1384
+ "model_module_version": "1.2.0",
1385
+ "model_name": "LayoutModel",
1386
+ "state": {
1387
+ "_model_module": "@jupyter-widgets/base",
1388
+ "_model_module_version": "1.2.0",
1389
+ "_model_name": "LayoutModel",
1390
+ "_view_count": null,
1391
+ "_view_module": "@jupyter-widgets/base",
1392
+ "_view_module_version": "1.2.0",
1393
+ "_view_name": "LayoutView",
1394
+ "align_content": null,
1395
+ "align_items": null,
1396
+ "align_self": null,
1397
+ "border": null,
1398
+ "bottom": null,
1399
+ "display": null,
1400
+ "flex": null,
1401
+ "flex_flow": null,
1402
+ "grid_area": null,
1403
+ "grid_auto_columns": null,
1404
+ "grid_auto_flow": null,
1405
+ "grid_auto_rows": null,
1406
+ "grid_column": null,
1407
+ "grid_gap": null,
1408
+ "grid_row": null,
1409
+ "grid_template_areas": null,
1410
+ "grid_template_columns": null,
1411
+ "grid_template_rows": null,
1412
+ "height": null,
1413
+ "justify_content": null,
1414
+ "justify_items": null,
1415
+ "left": null,
1416
+ "margin": null,
1417
+ "max_height": null,
1418
+ "max_width": null,
1419
+ "min_height": null,
1420
+ "min_width": null,
1421
+ "object_fit": null,
1422
+ "object_position": null,
1423
+ "order": null,
1424
+ "overflow": null,
1425
+ "overflow_x": null,
1426
+ "overflow_y": null,
1427
+ "padding": null,
1428
+ "right": null,
1429
+ "top": null,
1430
+ "visibility": null,
1431
+ "width": null
1432
+ }
1433
+ }
1434
+ },
1435
+ "version_major": 2,
1436
+ "version_minor": 0
1437
+ }
1438
+ }
1439
+ },
1440
+ "nbformat": 4,
1441
+ "nbformat_minor": 5
1442
+ }
XgBoost.py/XgBoost_ver1.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
XgBoost.py/XgBoost_ver2.ipynb ADDED
@@ -0,0 +1,1257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "459944e1",
7
+ "metadata": {
8
+ "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
9
+ "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
10
+ "execution": {
11
+ "iopub.execute_input": "2023-07-16T14:00:17.517706Z",
12
+ "iopub.status.busy": "2023-07-16T14:00:17.517322Z",
13
+ "iopub.status.idle": "2023-07-16T14:00:17.532915Z",
14
+ "shell.execute_reply": "2023-07-16T14:00:17.531918Z"
15
+ },
16
+ "papermill": {
17
+ "duration": 0.028949,
18
+ "end_time": "2023-07-16T14:00:17.535784",
19
+ "exception": false,
20
+ "start_time": "2023-07-16T14:00:17.506835",
21
+ "status": "completed"
22
+ },
23
+ "tags": []
24
+ },
25
+ "outputs": [
26
+ {
27
+ "name": "stdout",
28
+ "output_type": "stream",
29
+ "text": [
30
+ "/kaggle/input/icr-identify-age-related-conditions/sample_submission.csv\n",
31
+ "/kaggle/input/icr-identify-age-related-conditions/greeks.csv\n",
32
+ "/kaggle/input/icr-identify-age-related-conditions/train.csv\n",
33
+ "/kaggle/input/icr-identify-age-related-conditions/test.csv\n"
34
+ ]
35
+ }
36
+ ],
37
+ "source": [
38
+ "# This Python 3 environment comes with many helpful analytics libraries installed\n",
39
+ "# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n",
40
+ "# For example, here's several helpful packages to load\n",
41
+ "\n",
42
+ "import numpy as np # linear algebra\n",
43
+ "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
44
+ "\n",
45
+ "# Input data files are available in the read-only \"../input/\" directory\n",
46
+ "# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n",
47
+ "\n",
48
+ "import os\n",
49
+ "for dirname, _, filenames in os.walk('/kaggle/input'):\n",
50
+ " for filename in filenames:\n",
51
+ " print(os.path.join(dirname, filename))\n",
52
+ "\n",
53
+ "# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n",
54
+ "# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 2,
60
+ "id": "6f1fa358",
61
+ "metadata": {
62
+ "execution": {
63
+ "iopub.execute_input": "2023-07-16T14:00:17.554886Z",
64
+ "iopub.status.busy": "2023-07-16T14:00:17.553876Z",
65
+ "iopub.status.idle": "2023-07-16T14:00:21.000320Z",
66
+ "shell.execute_reply": "2023-07-16T14:00:20.998572Z"
67
+ },
68
+ "papermill": {
69
+ "duration": 3.459261,
70
+ "end_time": "2023-07-16T14:00:21.003513",
71
+ "exception": false,
72
+ "start_time": "2023-07-16T14:00:17.544252",
73
+ "status": "completed"
74
+ },
75
+ "tags": []
76
+ },
77
+ "outputs": [
78
+ {
79
+ "name": "stderr",
80
+ "output_type": "stream",
81
+ "text": [
82
+ "/opt/conda/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.5\n",
83
+ " warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
84
+ ]
85
+ }
86
+ ],
87
+ "source": [
88
+ "import numpy as np\n",
89
+ "import pandas as pd\n",
90
+ "from sklearn.preprocessing import LabelEncoder,normalize\n",
91
+ "from sklearn.ensemble import GradientBoostingClassifier,RandomForestClassifier\n",
92
+ "from sklearn.metrics import accuracy_score\n",
93
+ "from sklearn.impute import SimpleImputer\n",
94
+ "import imblearn\n",
95
+ "from imblearn.over_sampling import RandomOverSampler\n",
96
+ "from imblearn.under_sampling import RandomUnderSampler\n",
97
+ "import xgboost\n",
98
+ "import inspect\n",
99
+ "from collections import defaultdict\n",
100
+ "#from tabpfn import TabPFNClassifier\n",
101
+ "import lightgbm as lgb\n",
102
+ "import warnings\n",
103
+ "warnings.filterwarnings('ignore')\n",
104
+ "from sklearn.model_selection import KFold as KF, GridSearchCV\n",
105
+ "from tqdm.notebook import tqdm\n",
106
+ "from datetime import datetime"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": 3,
112
+ "id": "edcd8543",
113
+ "metadata": {
114
+ "execution": {
115
+ "iopub.execute_input": "2023-07-16T14:00:21.023332Z",
116
+ "iopub.status.busy": "2023-07-16T14:00:21.022856Z",
117
+ "iopub.status.idle": "2023-07-16T14:00:21.088898Z",
118
+ "shell.execute_reply": "2023-07-16T14:00:21.087472Z"
119
+ },
120
+ "papermill": {
121
+ "duration": 0.079354,
122
+ "end_time": "2023-07-16T14:00:21.092191",
123
+ "exception": false,
124
+ "start_time": "2023-07-16T14:00:21.012837",
125
+ "status": "completed"
126
+ },
127
+ "tags": []
128
+ },
129
+ "outputs": [],
130
+ "source": [
131
+ "\n",
132
+ "train_df = pd.read_csv(\"/kaggle/input/icr-identify-age-related-conditions/train.csv\")\n",
133
+ "test_df = pd.read_csv(\"/kaggle/input/icr-identify-age-related-conditions/test.csv\")\n",
134
+ "greeks_df = pd.read_csv(\"/kaggle/input/icr-identify-age-related-conditions/greeks.csv\")\n",
135
+ "sample_submission = pd.read_csv(\"/kaggle/input/icr-identify-age-related-conditions/sample_submission.csv\")"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": 4,
141
+ "id": "4da67da3",
142
+ "metadata": {
143
+ "execution": {
144
+ "iopub.execute_input": "2023-07-16T14:00:21.111296Z",
145
+ "iopub.status.busy": "2023-07-16T14:00:21.110878Z",
146
+ "iopub.status.idle": "2023-07-16T14:00:21.126219Z",
147
+ "shell.execute_reply": "2023-07-16T14:00:21.125017Z"
148
+ },
149
+ "papermill": {
150
+ "duration": 0.028324,
151
+ "end_time": "2023-07-16T14:00:21.129087",
152
+ "exception": false,
153
+ "start_time": "2023-07-16T14:00:21.100763",
154
+ "status": "completed"
155
+ },
156
+ "tags": []
157
+ },
158
+ "outputs": [],
159
+ "source": [
160
+ "\n",
161
+ "first_category = train_df.EJ.unique()[0]\n",
162
+ "train_df.EJ = train_df.EJ.eq(first_category).astype('int')\n",
163
+ "test_df.EJ = test_df.EJ.eq(first_category).astype('int')"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": 5,
169
+ "id": "5d460a95",
170
+ "metadata": {
171
+ "execution": {
172
+ "iopub.execute_input": "2023-07-16T14:00:21.148826Z",
173
+ "iopub.status.busy": "2023-07-16T14:00:21.148442Z",
174
+ "iopub.status.idle": "2023-07-16T14:00:21.155072Z",
175
+ "shell.execute_reply": "2023-07-16T14:00:21.153937Z"
176
+ },
177
+ "papermill": {
178
+ "duration": 0.020215,
179
+ "end_time": "2023-07-16T14:00:21.157532",
180
+ "exception": false,
181
+ "start_time": "2023-07-16T14:00:21.137317",
182
+ "status": "completed"
183
+ },
184
+ "tags": []
185
+ },
186
+ "outputs": [],
187
+ "source": [
188
+ "\n",
189
+ "def random_under_sampler(df):\n",
190
+ " neg, pos = np.bincount(df['Class'])\n",
191
+ " one_df = df.loc[df['Class'] == 1]\n",
192
+ " zero_df = df.loc[df['Class'] == 0]\n",
193
+ " zero_df = zero_df.sample(n=pos)\n",
194
+ " undersampled_df = pd.concat([zero_df, one_df])\n",
195
+ " return undersampled_df.sample(frac = 1)\n"
196
+ ]
197
+ },
198
+ {
199
+ "cell_type": "code",
200
+ "execution_count": 6,
201
+ "id": "20ce7205",
202
+ "metadata": {
203
+ "execution": {
204
+ "iopub.execute_input": "2023-07-16T14:00:21.175607Z",
205
+ "iopub.status.busy": "2023-07-16T14:00:21.174813Z",
206
+ "iopub.status.idle": "2023-07-16T14:00:21.188987Z",
207
+ "shell.execute_reply": "2023-07-16T14:00:21.187603Z"
208
+ },
209
+ "papermill": {
210
+ "duration": 0.0262,
211
+ "end_time": "2023-07-16T14:00:21.191636",
212
+ "exception": false,
213
+ "start_time": "2023-07-16T14:00:21.165436",
214
+ "status": "completed"
215
+ },
216
+ "tags": []
217
+ },
218
+ "outputs": [],
219
+ "source": [
220
+ "train_df_good = random_under_sampler(train_df)"
221
+ ]
222
+ },
223
+ {
224
+ "cell_type": "code",
225
+ "execution_count": 7,
226
+ "id": "f5cb2e31",
227
+ "metadata": {
228
+ "execution": {
229
+ "iopub.execute_input": "2023-07-16T14:00:21.210828Z",
230
+ "iopub.status.busy": "2023-07-16T14:00:21.209550Z",
231
+ "iopub.status.idle": "2023-07-16T14:00:21.219108Z",
232
+ "shell.execute_reply": "2023-07-16T14:00:21.217884Z"
233
+ },
234
+ "papermill": {
235
+ "duration": 0.021809,
236
+ "end_time": "2023-07-16T14:00:21.221645",
237
+ "exception": false,
238
+ "start_time": "2023-07-16T14:00:21.199836",
239
+ "status": "completed"
240
+ },
241
+ "tags": []
242
+ },
243
+ "outputs": [
244
+ {
245
+ "data": {
246
+ "text/plain": [
247
+ "(216, 58)"
248
+ ]
249
+ },
250
+ "execution_count": 7,
251
+ "metadata": {},
252
+ "output_type": "execute_result"
253
+ }
254
+ ],
255
+ "source": [
256
+ "\n",
257
+ "train_df_good.shape"
258
+ ]
259
+ },
260
+ {
261
+ "cell_type": "code",
262
+ "execution_count": 8,
263
+ "id": "6b3c6991",
264
+ "metadata": {
265
+ "execution": {
266
+ "iopub.execute_input": "2023-07-16T14:00:21.241474Z",
267
+ "iopub.status.busy": "2023-07-16T14:00:21.241030Z",
268
+ "iopub.status.idle": "2023-07-16T14:00:21.251174Z",
269
+ "shell.execute_reply": "2023-07-16T14:00:21.249737Z"
270
+ },
271
+ "papermill": {
272
+ "duration": 0.02339,
273
+ "end_time": "2023-07-16T14:00:21.253719",
274
+ "exception": false,
275
+ "start_time": "2023-07-16T14:00:21.230329",
276
+ "status": "completed"
277
+ },
278
+ "tags": []
279
+ },
280
+ "outputs": [],
281
+ "source": [
282
+ "predictor_columns = [n for n in train_df.columns if n != 'Class' and n != 'Id']\n",
283
+ "x= train_df[predictor_columns]\n",
284
+ "y = train_df['Class']"
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "code",
289
+ "execution_count": 9,
290
+ "id": "7bb0f756",
291
+ "metadata": {
292
+ "execution": {
293
+ "iopub.execute_input": "2023-07-16T14:00:21.273356Z",
294
+ "iopub.status.busy": "2023-07-16T14:00:21.272869Z",
295
+ "iopub.status.idle": "2023-07-16T14:00:21.278998Z",
296
+ "shell.execute_reply": "2023-07-16T14:00:21.277841Z"
297
+ },
298
+ "papermill": {
299
+ "duration": 0.018676,
300
+ "end_time": "2023-07-16T14:00:21.281241",
301
+ "exception": false,
302
+ "start_time": "2023-07-16T14:00:21.262565",
303
+ "status": "completed"
304
+ },
305
+ "tags": []
306
+ },
307
+ "outputs": [],
308
+ "source": [
309
+ "cv_outer = KF(n_splits = 10, shuffle=True, random_state=42)\n",
310
+ "cv_inner = KF(n_splits = 5, shuffle=True, random_state=42)"
311
+ ]
312
+ },
313
+ {
314
+ "cell_type": "code",
315
+ "execution_count": 10,
316
+ "id": "2cf318d7",
317
+ "metadata": {
318
+ "execution": {
319
+ "iopub.execute_input": "2023-07-16T14:00:21.300708Z",
320
+ "iopub.status.busy": "2023-07-16T14:00:21.299943Z",
321
+ "iopub.status.idle": "2023-07-16T14:00:21.308230Z",
322
+ "shell.execute_reply": "2023-07-16T14:00:21.307088Z"
323
+ },
324
+ "papermill": {
325
+ "duration": 0.020883,
326
+ "end_time": "2023-07-16T14:00:21.310917",
327
+ "exception": false,
328
+ "start_time": "2023-07-16T14:00:21.290034",
329
+ "status": "completed"
330
+ },
331
+ "tags": []
332
+ },
333
+ "outputs": [],
334
+ "source": [
335
+ "def balanced_log_loss(y_true, y_pred):\n",
336
+ " N_0 = np.sum(1 - y_true)\n",
337
+ " N_1 = np.sum(y_true)\n",
338
+ " \n",
339
+ " w_0 = 1 / N_0\n",
340
+ " w_1 = 1 / N_1\n",
341
+ " p_1 = np.clip(y_pred, 1e-15, 1 - 1e-15)\n",
342
+ " p_0 = 1 - p_1\n",
343
+ " log_loss_0 = -np.sum((1 - y_true) * np.log(p_0))\n",
344
+ " log_loss_1 = -np.sum(y_true * np.log(p_1))\n",
345
+ " balanced_log_loss = 2*(w_0 * log_loss_0 + w_1 * log_loss_1) / (w_0 + w_1)\n",
346
+ " return balanced_log_loss/(N_0+N_1)"
347
+ ]
348
+ },
349
+ {
350
+ "cell_type": "code",
351
+ "execution_count": 11,
352
+ "id": "e21b0bd7",
353
+ "metadata": {
354
+ "execution": {
355
+ "iopub.execute_input": "2023-07-16T14:00:21.331052Z",
356
+ "iopub.status.busy": "2023-07-16T14:00:21.330612Z",
357
+ "iopub.status.idle": "2023-07-16T14:00:21.343946Z",
358
+ "shell.execute_reply": "2023-07-16T14:00:21.343016Z"
359
+ },
360
+ "papermill": {
361
+ "duration": 0.026312,
362
+ "end_time": "2023-07-16T14:00:21.346767",
363
+ "exception": false,
364
+ "start_time": "2023-07-16T14:00:21.320455",
365
+ "status": "completed"
366
+ },
367
+ "tags": []
368
+ },
369
+ "outputs": [],
370
+ "source": [
371
+ "class Ensemble():\n",
372
+ " def __init__(self):\n",
373
+ " self.imputer = SimpleImputer(missing_values=np.nan, strategy='median')\n",
374
+ " self.classifiers =[xgboost.XGBClassifier(n_estimators=100,max_depth=3,learning_rate=0.2,subsample=0.9,colsample_bytree=0.85),\n",
375
+ " xgboost.XGBClassifier()]\n",
376
+ " #TabPFNClassifier(N_ensemble_configurations=128),\n",
377
+ " #TabPFNClassifier(N_ensemble_configurations=48)]\n",
378
+ " \n",
379
+ " def fit(self,X,y):\n",
380
+ " y = y.values\n",
381
+ " unique_classes, y = np.unique(y, return_inverse=True)\n",
382
+ " self.classes_ = unique_classes\n",
383
+ " first_category = X.EJ.unique()[0]\n",
384
+ " X.EJ = X.EJ.eq(first_category).astype('int')\n",
385
+ " X = self.imputer.fit_transform(X)\n",
386
+ "\n",
387
+ " for classifier in self.classifiers:\n",
388
+ " #if classifier==self.classifiers[2] or classifier==self.classifiers[3]:\n",
389
+ " # classifier.fit(X,y,overwrite_warning =True)\n",
390
+ " # else :\n",
391
+ " classifier.fit(X, y)\n",
392
+ " \n",
393
+ " def predict_proba(self, x):\n",
394
+ " x = self.imputer.transform(x)\n",
395
+ " probabilities = np.stack([classifier.predict_proba(x) for classifier in self.classifiers])\n",
396
+ " averaged_probabilities = np.mean(probabilities, axis=0)\n",
397
+ " class_0_est_instances = averaged_probabilities[:, 0].sum()\n",
398
+ " others_est_instances = averaged_probabilities[:, 1:].sum()\n",
399
+ " # Weighted probabilities based on class imbalance\n",
400
+ " new_probabilities = averaged_probabilities * np.array([[1/(class_0_est_instances if i==0 else others_est_instances) for i in range(averaged_probabilities.shape[1])]])\n",
401
+ " return new_probabilities / np.sum(new_probabilities, axis=1, keepdims=1) \n"
402
+ ]
403
+ },
404
+ {
405
+ "cell_type": "code",
406
+ "execution_count": 12,
407
+ "id": "38907581",
408
+ "metadata": {
409
+ "execution": {
410
+ "iopub.execute_input": "2023-07-16T14:00:21.367308Z",
411
+ "iopub.status.busy": "2023-07-16T14:00:21.366896Z",
412
+ "iopub.status.idle": "2023-07-16T14:00:21.379770Z",
413
+ "shell.execute_reply": "2023-07-16T14:00:21.378370Z"
414
+ },
415
+ "papermill": {
416
+ "duration": 0.026467,
417
+ "end_time": "2023-07-16T14:00:21.382714",
418
+ "exception": false,
419
+ "start_time": "2023-07-16T14:00:21.356247",
420
+ "status": "completed"
421
+ },
422
+ "tags": []
423
+ },
424
+ "outputs": [],
425
+ "source": [
426
+ "def training(model, x,y,y_meta):\n",
427
+ " outer_results = list()\n",
428
+ " best_loss = np.inf\n",
429
+ " split = 0\n",
430
+ " splits = 5\n",
431
+ " for train_idx,val_idx in tqdm(cv_inner.split(x), total = splits):\n",
432
+ " split+=1\n",
433
+ " x_train, x_val = x.iloc[train_idx],x.iloc[val_idx]\n",
434
+ " y_train, y_val = y_meta.iloc[train_idx], y.iloc[val_idx]\n",
435
+ " \n",
436
+ " model.fit(x_train, y_train)\n",
437
+ " y_pred = model.predict_proba(x_val)\n",
438
+ " probabilities = np.concatenate((y_pred[:,:1], np.sum(y_pred[:,1:], 1, keepdims=True)), axis=1)\n",
439
+ " p0 = probabilities[:,:1]\n",
440
+ " p0[p0 > 0.86] = 1\n",
441
+ " p0[p0 < 0.14] = 0\n",
442
+ " y_p = np.empty((y_pred.shape[0],))\n",
443
+ " for i in range(y_pred.shape[0]):\n",
444
+ " if p0[i]>=0.5:\n",
445
+ " y_p[i]= False\n",
446
+ " else :\n",
447
+ " y_p[i]=True\n",
448
+ " y_p = y_p.astype(int)\n",
449
+ " loss = balanced_log_loss(y_val,y_p)\n",
450
+ "\n",
451
+ " if loss<best_loss:\n",
452
+ " best_model = model\n",
453
+ " best_loss = loss\n",
454
+ " print('best_model_saved')\n",
455
+ " outer_results.append(loss)\n",
456
+ " print('>val_loss=%.5f, split = %.1f' % (loss,split))\n",
457
+ " print('LOSS: %.5f' % (np.mean(outer_results)))\n",
458
+ " return best_model"
459
+ ]
460
+ },
461
+ {
462
+ "cell_type": "code",
463
+ "execution_count": 13,
464
+ "id": "5783c385",
465
+ "metadata": {
466
+ "execution": {
467
+ "iopub.execute_input": "2023-07-16T14:00:21.401915Z",
468
+ "iopub.status.busy": "2023-07-16T14:00:21.401467Z",
469
+ "iopub.status.idle": "2023-07-16T14:00:21.425585Z",
470
+ "shell.execute_reply": "2023-07-16T14:00:21.424538Z"
471
+ },
472
+ "papermill": {
473
+ "duration": 0.036977,
474
+ "end_time": "2023-07-16T14:00:21.428411",
475
+ "exception": false,
476
+ "start_time": "2023-07-16T14:00:21.391434",
477
+ "status": "completed"
478
+ },
479
+ "tags": []
480
+ },
481
+ "outputs": [],
482
+ "source": [
483
+ "times = greeks_df.Epsilon.copy()\n",
484
+ "times[greeks_df.Epsilon != 'Unknown'] = greeks_df.Epsilon[greeks_df.Epsilon != 'Unknown'].map(lambda x: datetime.strptime(x,'%m/%d/%Y').toordinal())\n",
485
+ "times[greeks_df.Epsilon == 'Unknown'] = np.nan"
486
+ ]
487
+ },
488
+ {
489
+ "cell_type": "code",
490
+ "execution_count": 14,
491
+ "id": "b0264cd3",
492
+ "metadata": {
493
+ "execution": {
494
+ "iopub.execute_input": "2023-07-16T14:00:21.449211Z",
495
+ "iopub.status.busy": "2023-07-16T14:00:21.448215Z",
496
+ "iopub.status.idle": "2023-07-16T14:00:21.463314Z",
497
+ "shell.execute_reply": "2023-07-16T14:00:21.462240Z"
498
+ },
499
+ "papermill": {
500
+ "duration": 0.028401,
501
+ "end_time": "2023-07-16T14:00:21.466328",
502
+ "exception": false,
503
+ "start_time": "2023-07-16T14:00:21.437927",
504
+ "status": "completed"
505
+ },
506
+ "tags": []
507
+ },
508
+ "outputs": [],
509
+ "source": [
510
+ "train_pred_and_time = pd.concat((train_df, times), axis=1)\n",
511
+ "test_predictors = test_df[predictor_columns]\n",
512
+ "first_category = test_predictors.EJ.unique()[0]\n",
513
+ "test_predictors.EJ = test_predictors.EJ.eq(first_category).astype('int')\n",
514
+ "test_pred_and_time = np.concatenate((test_predictors, np.zeros((len(test_predictors), 1)) + train_pred_and_time.Epsilon.max() + 1), axis=1)"
515
+ ]
516
+ },
517
+ {
518
+ "cell_type": "code",
519
+ "execution_count": 15,
520
+ "id": "9c0d5363",
521
+ "metadata": {
522
+ "execution": {
523
+ "iopub.execute_input": "2023-07-16T14:00:21.487186Z",
524
+ "iopub.status.busy": "2023-07-16T14:00:21.486384Z",
525
+ "iopub.status.idle": "2023-07-16T14:00:21.530196Z",
526
+ "shell.execute_reply": "2023-07-16T14:00:21.528975Z"
527
+ },
528
+ "papermill": {
529
+ "duration": 0.057986,
530
+ "end_time": "2023-07-16T14:00:21.533007",
531
+ "exception": false,
532
+ "start_time": "2023-07-16T14:00:21.475021",
533
+ "status": "completed"
534
+ },
535
+ "tags": []
536
+ },
537
+ "outputs": [
538
+ {
539
+ "name": "stdout",
540
+ "output_type": "stream",
541
+ "text": [
542
+ "Original dataset shape\n",
543
+ "A 509\n",
544
+ "B 61\n",
545
+ "G 29\n",
546
+ "D 18\n",
547
+ "Name: Alpha, dtype: int64\n",
548
+ "Resample dataset shape\n",
549
+ "B 509\n",
550
+ "A 509\n",
551
+ "D 509\n",
552
+ "G 509\n",
553
+ "Name: Alpha, dtype: int64\n"
554
+ ]
555
+ }
556
+ ],
557
+ "source": [
558
+ "ros = RandomOverSampler(random_state=42)\n",
559
+ "train_ros, y_ros = ros.fit_resample(train_pred_and_time, greeks_df.Alpha)\n",
560
+ "print('Original dataset shape')\n",
561
+ "print(greeks_df.Alpha.value_counts())\n",
562
+ "print('Resample dataset shape')\n",
563
+ "print( y_ros.value_counts())"
564
+ ]
565
+ },
566
+ {
567
+ "cell_type": "code",
568
+ "execution_count": 16,
569
+ "id": "e375c3b6",
570
+ "metadata": {
571
+ "execution": {
572
+ "iopub.execute_input": "2023-07-16T14:00:21.552682Z",
573
+ "iopub.status.busy": "2023-07-16T14:00:21.551996Z",
574
+ "iopub.status.idle": "2023-07-16T14:00:21.560722Z",
575
+ "shell.execute_reply": "2023-07-16T14:00:21.559571Z"
576
+ },
577
+ "papermill": {
578
+ "duration": 0.021784,
579
+ "end_time": "2023-07-16T14:00:21.563678",
580
+ "exception": false,
581
+ "start_time": "2023-07-16T14:00:21.541894",
582
+ "status": "completed"
583
+ },
584
+ "tags": []
585
+ },
586
+ "outputs": [],
587
+ "source": [
588
+ "\n",
589
+ "x_ros = train_ros.drop(['Class', 'Id'],axis=1)\n",
590
+ "y_ = train_ros.Class"
591
+ ]
592
+ },
593
+ {
594
+ "cell_type": "code",
595
+ "execution_count": 17,
596
+ "id": "cdeca513",
597
+ "metadata": {
598
+ "execution": {
599
+ "iopub.execute_input": "2023-07-16T14:00:21.582486Z",
600
+ "iopub.status.busy": "2023-07-16T14:00:21.582001Z",
601
+ "iopub.status.idle": "2023-07-16T14:00:21.587344Z",
602
+ "shell.execute_reply": "2023-07-16T14:00:21.586235Z"
603
+ },
604
+ "papermill": {
605
+ "duration": 0.017561,
606
+ "end_time": "2023-07-16T14:00:21.589728",
607
+ "exception": false,
608
+ "start_time": "2023-07-16T14:00:21.572167",
609
+ "status": "completed"
610
+ },
611
+ "tags": []
612
+ },
613
+ "outputs": [],
614
+ "source": [
615
+ "\n",
616
+ "yt = Ensemble()"
617
+ ]
618
+ },
619
+ {
620
+ "cell_type": "code",
621
+ "execution_count": 18,
622
+ "id": "01637483",
623
+ "metadata": {
624
+ "execution": {
625
+ "iopub.execute_input": "2023-07-16T14:00:21.608991Z",
626
+ "iopub.status.busy": "2023-07-16T14:00:21.608586Z",
627
+ "iopub.status.idle": "2023-07-16T14:00:29.681349Z",
628
+ "shell.execute_reply": "2023-07-16T14:00:29.680406Z"
629
+ },
630
+ "papermill": {
631
+ "duration": 8.085504,
632
+ "end_time": "2023-07-16T14:00:29.683798",
633
+ "exception": false,
634
+ "start_time": "2023-07-16T14:00:21.598294",
635
+ "status": "completed"
636
+ },
637
+ "tags": []
638
+ },
639
+ "outputs": [
640
+ {
641
+ "data": {
642
+ "application/vnd.jupyter.widget-view+json": {
643
+ "model_id": "18b1b7454e0e48319b05b5a993f47cbd",
644
+ "version_major": 2,
645
+ "version_minor": 0
646
+ },
647
+ "text/plain": [
648
+ " 0%| | 0/5 [00:00<?, ?it/s]"
649
+ ]
650
+ },
651
+ "metadata": {},
652
+ "output_type": "display_data"
653
+ },
654
+ {
655
+ "name": "stdout",
656
+ "output_type": "stream",
657
+ "text": [
658
+ "best_model_saved\n",
659
+ ">val_loss=0.49134, split = 1.0\n",
660
+ "best_model_saved\n",
661
+ ">val_loss=0.24771, split = 2.0\n",
662
+ "best_model_saved\n",
663
+ ">val_loss=0.00000, split = 3.0\n",
664
+ ">val_loss=0.13220, split = 4.0\n",
665
+ ">val_loss=0.40159, split = 5.0\n",
666
+ "LOSS: 0.25457\n"
667
+ ]
668
+ }
669
+ ],
670
+ "source": [
671
+ "\n",
672
+ "m = training(yt,x_ros,y_,y_ros)"
673
+ ]
674
+ },
675
+ {
676
+ "cell_type": "code",
677
+ "execution_count": 19,
678
+ "id": "52603914",
679
+ "metadata": {
680
+ "execution": {
681
+ "iopub.execute_input": "2023-07-16T14:00:29.705219Z",
682
+ "iopub.status.busy": "2023-07-16T14:00:29.704713Z",
683
+ "iopub.status.idle": "2023-07-16T14:00:29.715873Z",
684
+ "shell.execute_reply": "2023-07-16T14:00:29.714582Z"
685
+ },
686
+ "papermill": {
687
+ "duration": 0.025774,
688
+ "end_time": "2023-07-16T14:00:29.718965",
689
+ "exception": false,
690
+ "start_time": "2023-07-16T14:00:29.693191",
691
+ "status": "completed"
692
+ },
693
+ "tags": []
694
+ },
695
+ "outputs": [
696
+ {
697
+ "data": {
698
+ "text/plain": [
699
+ "1 0.75\n",
700
+ "0 0.25\n",
701
+ "Name: Class, dtype: float64"
702
+ ]
703
+ },
704
+ "execution_count": 19,
705
+ "metadata": {},
706
+ "output_type": "execute_result"
707
+ }
708
+ ],
709
+ "source": [
710
+ "y_.value_counts()/y_.shape[0]"
711
+ ]
712
+ },
713
+ {
714
+ "cell_type": "code",
715
+ "execution_count": 20,
716
+ "id": "754e9c23",
717
+ "metadata": {
718
+ "execution": {
719
+ "iopub.execute_input": "2023-07-16T14:00:29.740372Z",
720
+ "iopub.status.busy": "2023-07-16T14:00:29.739879Z",
721
+ "iopub.status.idle": "2023-07-16T14:00:29.752997Z",
722
+ "shell.execute_reply": "2023-07-16T14:00:29.752086Z"
723
+ },
724
+ "papermill": {
725
+ "duration": 0.027082,
726
+ "end_time": "2023-07-16T14:00:29.755427",
727
+ "exception": false,
728
+ "start_time": "2023-07-16T14:00:29.728345",
729
+ "status": "completed"
730
+ },
731
+ "tags": []
732
+ },
733
+ "outputs": [],
734
+ "source": [
735
+ "\n",
736
+ "y_pred = m.predict_proba(test_pred_and_time)\n",
737
+ "probabilities = np.concatenate((y_pred[:,:1], np.sum(y_pred[:,1:], 1, keepdims=True)), axis=1)\n",
738
+ "p0 = probabilities[:,:1]\n",
739
+ "p0[p0 > 0.59] = 1\n",
740
+ "p0[p0 < 0.28] = 0"
741
+ ]
742
+ },
743
+ {
744
+ "cell_type": "code",
745
+ "execution_count": 21,
746
+ "id": "b6505a78",
747
+ "metadata": {
748
+ "execution": {
749
+ "iopub.execute_input": "2023-07-16T14:00:29.777401Z",
750
+ "iopub.status.busy": "2023-07-16T14:00:29.776987Z",
751
+ "iopub.status.idle": "2023-07-16T14:00:29.790086Z",
752
+ "shell.execute_reply": "2023-07-16T14:00:29.788839Z"
753
+ },
754
+ "papermill": {
755
+ "duration": 0.026657,
756
+ "end_time": "2023-07-16T14:00:29.793006",
757
+ "exception": false,
758
+ "start_time": "2023-07-16T14:00:29.766349",
759
+ "status": "completed"
760
+ },
761
+ "tags": []
762
+ },
763
+ "outputs": [],
764
+ "source": [
765
+ "\n",
766
+ "submission = pd.DataFrame(test_df[\"Id\"], columns=[\"Id\"])\n",
767
+ "submission[\"class_0\"] = p0\n",
768
+ "submission[\"class_1\"] = 1 - p0\n",
769
+ "submission.to_csv('submission.csv', index=False)"
770
+ ]
771
+ },
772
+ {
773
+ "cell_type": "code",
774
+ "execution_count": 22,
775
+ "id": "a185d0b6",
776
+ "metadata": {
777
+ "execution": {
778
+ "iopub.execute_input": "2023-07-16T14:00:29.812749Z",
779
+ "iopub.status.busy": "2023-07-16T14:00:29.812370Z",
780
+ "iopub.status.idle": "2023-07-16T14:00:29.832503Z",
781
+ "shell.execute_reply": "2023-07-16T14:00:29.831080Z"
782
+ },
783
+ "papermill": {
784
+ "duration": 0.033126,
785
+ "end_time": "2023-07-16T14:00:29.835328",
786
+ "exception": false,
787
+ "start_time": "2023-07-16T14:00:29.802202",
788
+ "status": "completed"
789
+ },
790
+ "tags": []
791
+ },
792
+ "outputs": [
793
+ {
794
+ "data": {
795
+ "text/html": [
796
+ "<div>\n",
797
+ "<style scoped>\n",
798
+ " .dataframe tbody tr th:only-of-type {\n",
799
+ " vertical-align: middle;\n",
800
+ " }\n",
801
+ "\n",
802
+ " .dataframe tbody tr th {\n",
803
+ " vertical-align: top;\n",
804
+ " }\n",
805
+ "\n",
806
+ " .dataframe thead th {\n",
807
+ " text-align: right;\n",
808
+ " }\n",
809
+ "</style>\n",
810
+ "<table border=\"1\" class=\"dataframe\">\n",
811
+ " <thead>\n",
812
+ " <tr style=\"text-align: right;\">\n",
813
+ " <th></th>\n",
814
+ " <th>Id</th>\n",
815
+ " <th>class_0</th>\n",
816
+ " <th>class_1</th>\n",
817
+ " </tr>\n",
818
+ " </thead>\n",
819
+ " <tbody>\n",
820
+ " <tr>\n",
821
+ " <th>0</th>\n",
822
+ " <td>00eed32682bb</td>\n",
823
+ " <td>0.5</td>\n",
824
+ " <td>0.5</td>\n",
825
+ " </tr>\n",
826
+ " <tr>\n",
827
+ " <th>1</th>\n",
828
+ " <td>010ebe33f668</td>\n",
829
+ " <td>0.5</td>\n",
830
+ " <td>0.5</td>\n",
831
+ " </tr>\n",
832
+ " <tr>\n",
833
+ " <th>2</th>\n",
834
+ " <td>02fa521e1838</td>\n",
835
+ " <td>0.5</td>\n",
836
+ " <td>0.5</td>\n",
837
+ " </tr>\n",
838
+ " <tr>\n",
839
+ " <th>3</th>\n",
840
+ " <td>040e15f562a2</td>\n",
841
+ " <td>0.5</td>\n",
842
+ " <td>0.5</td>\n",
843
+ " </tr>\n",
844
+ " <tr>\n",
845
+ " <th>4</th>\n",
846
+ " <td>046e85c7cc7f</td>\n",
847
+ " <td>0.5</td>\n",
848
+ " <td>0.5</td>\n",
849
+ " </tr>\n",
850
+ " </tbody>\n",
851
+ "</table>\n",
852
+ "</div>"
853
+ ],
854
+ "text/plain": [
855
+ " Id class_0 class_1\n",
856
+ "0 00eed32682bb 0.5 0.5\n",
857
+ "1 010ebe33f668 0.5 0.5\n",
858
+ "2 02fa521e1838 0.5 0.5\n",
859
+ "3 040e15f562a2 0.5 0.5\n",
860
+ "4 046e85c7cc7f 0.5 0.5"
861
+ ]
862
+ },
863
+ "execution_count": 22,
864
+ "metadata": {},
865
+ "output_type": "execute_result"
866
+ }
867
+ ],
868
+ "source": [
869
+ "submission_df = pd.read_csv('submission.csv')\n",
870
+ "submission_df"
871
+ ]
872
+ }
873
+ ],
874
+ "metadata": {
875
+ "kernelspec": {
876
+ "display_name": "Python 3",
877
+ "language": "python",
878
+ "name": "python3"
879
+ },
880
+ "language_info": {
881
+ "codemirror_mode": {
882
+ "name": "ipython",
883
+ "version": 3
884
+ },
885
+ "file_extension": ".py",
886
+ "mimetype": "text/x-python",
887
+ "name": "python",
888
+ "nbconvert_exporter": "python",
889
+ "pygments_lexer": "ipython3",
890
+ "version": "3.10.12"
891
+ },
892
+ "papermill": {
893
+ "default_parameters": {},
894
+ "duration": 25.292095,
895
+ "end_time": "2023-07-16T14:00:30.969842",
896
+ "environment_variables": {},
897
+ "exception": null,
898
+ "input_path": "__notebook__.ipynb",
899
+ "output_path": "__notebook__.ipynb",
900
+ "parameters": {},
901
+ "start_time": "2023-07-16T14:00:05.677747",
902
+ "version": "2.4.0"
903
+ },
904
+ "widgets": {
905
+ "application/vnd.jupyter.widget-state+json": {
906
+ "state": {
907
+ "016218d199874567ad89e22e2df78d64": {
908
+ "model_module": "@jupyter-widgets/controls",
909
+ "model_module_version": "1.5.0",
910
+ "model_name": "DescriptionStyleModel",
911
+ "state": {
912
+ "_model_module": "@jupyter-widgets/controls",
913
+ "_model_module_version": "1.5.0",
914
+ "_model_name": "DescriptionStyleModel",
915
+ "_view_count": null,
916
+ "_view_module": "@jupyter-widgets/base",
917
+ "_view_module_version": "1.2.0",
918
+ "_view_name": "StyleView",
919
+ "description_width": ""
920
+ }
921
+ },
922
+ "03bdf5df19ba49cd96e4b5209e92c6ab": {
923
+ "model_module": "@jupyter-widgets/controls",
924
+ "model_module_version": "1.5.0",
925
+ "model_name": "ProgressStyleModel",
926
+ "state": {
927
+ "_model_module": "@jupyter-widgets/controls",
928
+ "_model_module_version": "1.5.0",
929
+ "_model_name": "ProgressStyleModel",
930
+ "_view_count": null,
931
+ "_view_module": "@jupyter-widgets/base",
932
+ "_view_module_version": "1.2.0",
933
+ "_view_name": "StyleView",
934
+ "bar_color": null,
935
+ "description_width": ""
936
+ }
937
+ },
938
+ "03e57d986f05468b899b8030b167a552": {
939
+ "model_module": "@jupyter-widgets/base",
940
+ "model_module_version": "1.2.0",
941
+ "model_name": "LayoutModel",
942
+ "state": {
943
+ "_model_module": "@jupyter-widgets/base",
944
+ "_model_module_version": "1.2.0",
945
+ "_model_name": "LayoutModel",
946
+ "_view_count": null,
947
+ "_view_module": "@jupyter-widgets/base",
948
+ "_view_module_version": "1.2.0",
949
+ "_view_name": "LayoutView",
950
+ "align_content": null,
951
+ "align_items": null,
952
+ "align_self": null,
953
+ "border": null,
954
+ "bottom": null,
955
+ "display": null,
956
+ "flex": null,
957
+ "flex_flow": null,
958
+ "grid_area": null,
959
+ "grid_auto_columns": null,
960
+ "grid_auto_flow": null,
961
+ "grid_auto_rows": null,
962
+ "grid_column": null,
963
+ "grid_gap": null,
964
+ "grid_row": null,
965
+ "grid_template_areas": null,
966
+ "grid_template_columns": null,
967
+ "grid_template_rows": null,
968
+ "height": null,
969
+ "justify_content": null,
970
+ "justify_items": null,
971
+ "left": null,
972
+ "margin": null,
973
+ "max_height": null,
974
+ "max_width": null,
975
+ "min_height": null,
976
+ "min_width": null,
977
+ "object_fit": null,
978
+ "object_position": null,
979
+ "order": null,
980
+ "overflow": null,
981
+ "overflow_x": null,
982
+ "overflow_y": null,
983
+ "padding": null,
984
+ "right": null,
985
+ "top": null,
986
+ "visibility": null,
987
+ "width": null
988
+ }
989
+ },
990
+ "13642f0614b145fd9bb3043dac743a62": {
991
+ "model_module": "@jupyter-widgets/base",
992
+ "model_module_version": "1.2.0",
993
+ "model_name": "LayoutModel",
994
+ "state": {
995
+ "_model_module": "@jupyter-widgets/base",
996
+ "_model_module_version": "1.2.0",
997
+ "_model_name": "LayoutModel",
998
+ "_view_count": null,
999
+ "_view_module": "@jupyter-widgets/base",
1000
+ "_view_module_version": "1.2.0",
1001
+ "_view_name": "LayoutView",
1002
+ "align_content": null,
1003
+ "align_items": null,
1004
+ "align_self": null,
1005
+ "border": null,
1006
+ "bottom": null,
1007
+ "display": null,
1008
+ "flex": null,
1009
+ "flex_flow": null,
1010
+ "grid_area": null,
1011
+ "grid_auto_columns": null,
1012
+ "grid_auto_flow": null,
1013
+ "grid_auto_rows": null,
1014
+ "grid_column": null,
1015
+ "grid_gap": null,
1016
+ "grid_row": null,
1017
+ "grid_template_areas": null,
1018
+ "grid_template_columns": null,
1019
+ "grid_template_rows": null,
1020
+ "height": null,
1021
+ "justify_content": null,
1022
+ "justify_items": null,
1023
+ "left": null,
1024
+ "margin": null,
1025
+ "max_height": null,
1026
+ "max_width": null,
1027
+ "min_height": null,
1028
+ "min_width": null,
1029
+ "object_fit": null,
1030
+ "object_position": null,
1031
+ "order": null,
1032
+ "overflow": null,
1033
+ "overflow_x": null,
1034
+ "overflow_y": null,
1035
+ "padding": null,
1036
+ "right": null,
1037
+ "top": null,
1038
+ "visibility": null,
1039
+ "width": null
1040
+ }
1041
+ },
1042
+ "18b1b7454e0e48319b05b5a993f47cbd": {
1043
+ "model_module": "@jupyter-widgets/controls",
1044
+ "model_module_version": "1.5.0",
1045
+ "model_name": "HBoxModel",
1046
+ "state": {
1047
+ "_dom_classes": [],
1048
+ "_model_module": "@jupyter-widgets/controls",
1049
+ "_model_module_version": "1.5.0",
1050
+ "_model_name": "HBoxModel",
1051
+ "_view_count": null,
1052
+ "_view_module": "@jupyter-widgets/controls",
1053
+ "_view_module_version": "1.5.0",
1054
+ "_view_name": "HBoxView",
1055
+ "box_style": "",
1056
+ "children": [
1057
+ "IPY_MODEL_6c64067679314b1fbda62eaa0ea1c212",
1058
+ "IPY_MODEL_dd388a4e894d44f5b3b44bf9a8804c11",
1059
+ "IPY_MODEL_212e14b55d984de5958562a66c4e3a79"
1060
+ ],
1061
+ "layout": "IPY_MODEL_adba34cf672e4fde858e50d870323f36"
1062
+ }
1063
+ },
1064
+ "212e14b55d984de5958562a66c4e3a79": {
1065
+ "model_module": "@jupyter-widgets/controls",
1066
+ "model_module_version": "1.5.0",
1067
+ "model_name": "HTMLModel",
1068
+ "state": {
1069
+ "_dom_classes": [],
1070
+ "_model_module": "@jupyter-widgets/controls",
1071
+ "_model_module_version": "1.5.0",
1072
+ "_model_name": "HTMLModel",
1073
+ "_view_count": null,
1074
+ "_view_module": "@jupyter-widgets/controls",
1075
+ "_view_module_version": "1.5.0",
1076
+ "_view_name": "HTMLView",
1077
+ "description": "",
1078
+ "description_tooltip": null,
1079
+ "layout": "IPY_MODEL_13642f0614b145fd9bb3043dac743a62",
1080
+ "placeholder": "​",
1081
+ "style": "IPY_MODEL_016218d199874567ad89e22e2df78d64",
1082
+ "value": " 5/5 [00:08&lt;00:00, 1.61s/it]"
1083
+ }
1084
+ },
1085
+ "307cdcab85c640fa9230e743fc8a4849": {
1086
+ "model_module": "@jupyter-widgets/base",
1087
+ "model_module_version": "1.2.0",
1088
+ "model_name": "LayoutModel",
1089
+ "state": {
1090
+ "_model_module": "@jupyter-widgets/base",
1091
+ "_model_module_version": "1.2.0",
1092
+ "_model_name": "LayoutModel",
1093
+ "_view_count": null,
1094
+ "_view_module": "@jupyter-widgets/base",
1095
+ "_view_module_version": "1.2.0",
1096
+ "_view_name": "LayoutView",
1097
+ "align_content": null,
1098
+ "align_items": null,
1099
+ "align_self": null,
1100
+ "border": null,
1101
+ "bottom": null,
1102
+ "display": null,
1103
+ "flex": null,
1104
+ "flex_flow": null,
1105
+ "grid_area": null,
1106
+ "grid_auto_columns": null,
1107
+ "grid_auto_flow": null,
1108
+ "grid_auto_rows": null,
1109
+ "grid_column": null,
1110
+ "grid_gap": null,
1111
+ "grid_row": null,
1112
+ "grid_template_areas": null,
1113
+ "grid_template_columns": null,
1114
+ "grid_template_rows": null,
1115
+ "height": null,
1116
+ "justify_content": null,
1117
+ "justify_items": null,
1118
+ "left": null,
1119
+ "margin": null,
1120
+ "max_height": null,
1121
+ "max_width": null,
1122
+ "min_height": null,
1123
+ "min_width": null,
1124
+ "object_fit": null,
1125
+ "object_position": null,
1126
+ "order": null,
1127
+ "overflow": null,
1128
+ "overflow_x": null,
1129
+ "overflow_y": null,
1130
+ "padding": null,
1131
+ "right": null,
1132
+ "top": null,
1133
+ "visibility": null,
1134
+ "width": null
1135
+ }
1136
+ },
1137
+ "65aed14b169c438495157fc4afebbb6e": {
1138
+ "model_module": "@jupyter-widgets/controls",
1139
+ "model_module_version": "1.5.0",
1140
+ "model_name": "DescriptionStyleModel",
1141
+ "state": {
1142
+ "_model_module": "@jupyter-widgets/controls",
1143
+ "_model_module_version": "1.5.0",
1144
+ "_model_name": "DescriptionStyleModel",
1145
+ "_view_count": null,
1146
+ "_view_module": "@jupyter-widgets/base",
1147
+ "_view_module_version": "1.2.0",
1148
+ "_view_name": "StyleView",
1149
+ "description_width": ""
1150
+ }
1151
+ },
1152
+ "6c64067679314b1fbda62eaa0ea1c212": {
1153
+ "model_module": "@jupyter-widgets/controls",
1154
+ "model_module_version": "1.5.0",
1155
+ "model_name": "HTMLModel",
1156
+ "state": {
1157
+ "_dom_classes": [],
1158
+ "_model_module": "@jupyter-widgets/controls",
1159
+ "_model_module_version": "1.5.0",
1160
+ "_model_name": "HTMLModel",
1161
+ "_view_count": null,
1162
+ "_view_module": "@jupyter-widgets/controls",
1163
+ "_view_module_version": "1.5.0",
1164
+ "_view_name": "HTMLView",
1165
+ "description": "",
1166
+ "description_tooltip": null,
1167
+ "layout": "IPY_MODEL_03e57d986f05468b899b8030b167a552",
1168
+ "placeholder": "​",
1169
+ "style": "IPY_MODEL_65aed14b169c438495157fc4afebbb6e",
1170
+ "value": "100%"
1171
+ }
1172
+ },
1173
+ "adba34cf672e4fde858e50d870323f36": {
1174
+ "model_module": "@jupyter-widgets/base",
1175
+ "model_module_version": "1.2.0",
1176
+ "model_name": "LayoutModel",
1177
+ "state": {
1178
+ "_model_module": "@jupyter-widgets/base",
1179
+ "_model_module_version": "1.2.0",
1180
+ "_model_name": "LayoutModel",
1181
+ "_view_count": null,
1182
+ "_view_module": "@jupyter-widgets/base",
1183
+ "_view_module_version": "1.2.0",
1184
+ "_view_name": "LayoutView",
1185
+ "align_content": null,
1186
+ "align_items": null,
1187
+ "align_self": null,
1188
+ "border": null,
1189
+ "bottom": null,
1190
+ "display": null,
1191
+ "flex": null,
1192
+ "flex_flow": null,
1193
+ "grid_area": null,
1194
+ "grid_auto_columns": null,
1195
+ "grid_auto_flow": null,
1196
+ "grid_auto_rows": null,
1197
+ "grid_column": null,
1198
+ "grid_gap": null,
1199
+ "grid_row": null,
1200
+ "grid_template_areas": null,
1201
+ "grid_template_columns": null,
1202
+ "grid_template_rows": null,
1203
+ "height": null,
1204
+ "justify_content": null,
1205
+ "justify_items": null,
1206
+ "left": null,
1207
+ "margin": null,
1208
+ "max_height": null,
1209
+ "max_width": null,
1210
+ "min_height": null,
1211
+ "min_width": null,
1212
+ "object_fit": null,
1213
+ "object_position": null,
1214
+ "order": null,
1215
+ "overflow": null,
1216
+ "overflow_x": null,
1217
+ "overflow_y": null,
1218
+ "padding": null,
1219
+ "right": null,
1220
+ "top": null,
1221
+ "visibility": null,
1222
+ "width": null
1223
+ }
1224
+ },
1225
+ "dd388a4e894d44f5b3b44bf9a8804c11": {
1226
+ "model_module": "@jupyter-widgets/controls",
1227
+ "model_module_version": "1.5.0",
1228
+ "model_name": "FloatProgressModel",
1229
+ "state": {
1230
+ "_dom_classes": [],
1231
+ "_model_module": "@jupyter-widgets/controls",
1232
+ "_model_module_version": "1.5.0",
1233
+ "_model_name": "FloatProgressModel",
1234
+ "_view_count": null,
1235
+ "_view_module": "@jupyter-widgets/controls",
1236
+ "_view_module_version": "1.5.0",
1237
+ "_view_name": "ProgressView",
1238
+ "bar_style": "success",
1239
+ "description": "",
1240
+ "description_tooltip": null,
1241
+ "layout": "IPY_MODEL_307cdcab85c640fa9230e743fc8a4849",
1242
+ "max": 5.0,
1243
+ "min": 0.0,
1244
+ "orientation": "horizontal",
1245
+ "style": "IPY_MODEL_03bdf5df19ba49cd96e4b5209e92c6ab",
1246
+ "value": 5.0
1247
+ }
1248
+ }
1249
+ },
1250
+ "version_major": 2,
1251
+ "version_minor": 0
1252
+ }
1253
+ }
1254
+ },
1255
+ "nbformat": 4,
1256
+ "nbformat_minor": 5
1257
+ }
XgBoost.py/XgBoost_ver3.ipynb ADDED
@@ -0,0 +1,1431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "4ce707b2",
7
+ "metadata": {
8
+ "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
9
+ "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
10
+ "execution": {
11
+ "iopub.execute_input": "2023-07-21T15:52:46.226344Z",
12
+ "iopub.status.busy": "2023-07-21T15:52:46.225376Z",
13
+ "iopub.status.idle": "2023-07-21T15:52:46.245756Z",
14
+ "shell.execute_reply": "2023-07-21T15:52:46.244748Z"
15
+ },
16
+ "papermill": {
17
+ "duration": 0.034273,
18
+ "end_time": "2023-07-21T15:52:46.248766",
19
+ "exception": false,
20
+ "start_time": "2023-07-21T15:52:46.214493",
21
+ "status": "completed"
22
+ },
23
+ "tags": []
24
+ },
25
+ "outputs": [
26
+ {
27
+ "name": "stdout",
28
+ "output_type": "stream",
29
+ "text": [
30
+ "/kaggle/input/icr-identify-age-related-conditions/sample_submission.csv\n",
31
+ "/kaggle/input/icr-identify-age-related-conditions/greeks.csv\n",
32
+ "/kaggle/input/icr-identify-age-related-conditions/train.csv\n",
33
+ "/kaggle/input/icr-identify-age-related-conditions/test.csv\n",
34
+ "/kaggle/input/tabpfn-019-whl/tabpfn-0.1.9-py3-none-any.whl\n",
35
+ "/kaggle/input/tabpfn-019-whl/prior_diff_real_checkpoint_n_0_epoch_42.cpkt\n",
36
+ "/kaggle/input/tabpfn-019-whl/prior_diff_real_checkpoint_n_0_epoch_100.cpkt\n"
37
+ ]
38
+ }
39
+ ],
40
+ "source": [
41
+ "# This Python 3 environment comes with many helpful analytics libraries installed\n",
42
+ "# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n",
43
+ "# For example, here's several helpful packages to load\n",
44
+ "\n",
45
+ "import numpy as np # linear algebra\n",
46
+ "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
47
+ "\n",
48
+ "# Input data files are available in the read-only \"../input/\" directory\n",
49
+ "# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n",
50
+ "\n",
51
+ "import os\n",
52
+ "for dirname, _, filenames in os.walk('/kaggle/input'):\n",
53
+ " for filename in filenames:\n",
54
+ " print(os.path.join(dirname, filename))\n",
55
+ "\n",
56
+ "# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n",
57
+ "# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": 2,
63
+ "id": "9dd44f97",
64
+ "metadata": {
65
+ "execution": {
66
+ "iopub.execute_input": "2023-07-21T15:52:46.269079Z",
67
+ "iopub.status.busy": "2023-07-21T15:52:46.268284Z",
68
+ "iopub.status.idle": "2023-07-21T15:53:21.724129Z",
69
+ "shell.execute_reply": "2023-07-21T15:53:21.722662Z"
70
+ },
71
+ "papermill": {
72
+ "duration": 35.469555,
73
+ "end_time": "2023-07-21T15:53:21.727323",
74
+ "exception": false,
75
+ "start_time": "2023-07-21T15:52:46.257768",
76
+ "status": "completed"
77
+ },
78
+ "tags": []
79
+ },
80
+ "outputs": [
81
+ {
82
+ "name": "stdout",
83
+ "output_type": "stream",
84
+ "text": [
85
+ "Processing /kaggle/input/tabpfn-019-whl/tabpfn-0.1.9-py3-none-any.whl\r\n",
86
+ "Requirement already satisfied: numpy>=1.21.2 in /opt/conda/lib/python3.10/site-packages (from tabpfn==0.1.9) (1.23.5)\r\n",
87
+ "Requirement already satisfied: pyyaml>=5.4.1 in /opt/conda/lib/python3.10/site-packages (from tabpfn==0.1.9) (6.0)\r\n",
88
+ "Requirement already satisfied: requests>=2.23.0 in /opt/conda/lib/python3.10/site-packages (from tabpfn==0.1.9) (2.31.0)\r\n",
89
+ "Requirement already satisfied: scikit-learn>=0.24.2 in /opt/conda/lib/python3.10/site-packages (from tabpfn==0.1.9) (1.2.2)\r\n",
90
+ "Requirement already satisfied: torch>=1.9.0 in /opt/conda/lib/python3.10/site-packages (from tabpfn==0.1.9) (2.0.0+cpu)\r\n",
91
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests>=2.23.0->tabpfn==0.1.9) (3.1.0)\r\n",
92
+ "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests>=2.23.0->tabpfn==0.1.9) (3.4)\r\n",
93
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests>=2.23.0->tabpfn==0.1.9) (1.26.15)\r\n",
94
+ "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests>=2.23.0->tabpfn==0.1.9) (2023.5.7)\r\n",
95
+ "Requirement already satisfied: scipy>=1.3.2 in /opt/conda/lib/python3.10/site-packages (from scikit-learn>=0.24.2->tabpfn==0.1.9) (1.11.1)\r\n",
96
+ "Requirement already satisfied: joblib>=1.1.1 in /opt/conda/lib/python3.10/site-packages (from scikit-learn>=0.24.2->tabpfn==0.1.9) (1.2.0)\r\n",
97
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.10/site-packages (from scikit-learn>=0.24.2->tabpfn==0.1.9) (3.1.0)\r\n",
98
+ "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch>=1.9.0->tabpfn==0.1.9) (3.12.2)\r\n",
99
+ "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.10/site-packages (from torch>=1.9.0->tabpfn==0.1.9) (4.6.3)\r\n",
100
+ "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch>=1.9.0->tabpfn==0.1.9) (1.12)\r\n",
101
+ "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch>=1.9.0->tabpfn==0.1.9) (3.1)\r\n",
102
+ "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch>=1.9.0->tabpfn==0.1.9) (3.1.2)\r\n",
103
+ "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch>=1.9.0->tabpfn==0.1.9) (2.1.3)\r\n",
104
+ "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch>=1.9.0->tabpfn==0.1.9) (1.3.0)\r\n",
105
+ "Installing collected packages: tabpfn\r\n",
106
+ "Successfully installed tabpfn-0.1.9\r\n"
107
+ ]
108
+ }
109
+ ],
110
+ "source": [
111
+ "!pip install /kaggle/input/tabpfn-019-whl/tabpfn-0.1.9-py3-none-any.whl"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": 3,
117
+ "id": "c9b8d056",
118
+ "metadata": {
119
+ "execution": {
120
+ "iopub.execute_input": "2023-07-21T15:53:21.748477Z",
121
+ "iopub.status.busy": "2023-07-21T15:53:21.748021Z",
122
+ "iopub.status.idle": "2023-07-21T15:53:24.917894Z",
123
+ "shell.execute_reply": "2023-07-21T15:53:24.916503Z"
124
+ },
125
+ "papermill": {
126
+ "duration": 3.183937,
127
+ "end_time": "2023-07-21T15:53:24.920802",
128
+ "exception": false,
129
+ "start_time": "2023-07-21T15:53:21.736865",
130
+ "status": "completed"
131
+ },
132
+ "tags": []
133
+ },
134
+ "outputs": [],
135
+ "source": [
136
+ "!mkdir /opt/conda/lib/python3.10/site-packages/tabpfn/models_diff\n",
137
+ "!cp /kaggle/input/tabpfn-019-whl/prior_diff_real_checkpoint_n_0_epoch_100.cpkt /opt/conda/lib/python3.10/site-packages/tabpfn/models_diff/"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": 4,
143
+ "id": "a30dcb9b",
144
+ "metadata": {
145
+ "execution": {
146
+ "iopub.execute_input": "2023-07-21T15:53:24.945098Z",
147
+ "iopub.status.busy": "2023-07-21T15:53:24.944662Z",
148
+ "iopub.status.idle": "2023-07-21T15:53:32.364411Z",
149
+ "shell.execute_reply": "2023-07-21T15:53:32.363046Z"
150
+ },
151
+ "papermill": {
152
+ "duration": 7.435592,
153
+ "end_time": "2023-07-21T15:53:32.367513",
154
+ "exception": false,
155
+ "start_time": "2023-07-21T15:53:24.931921",
156
+ "status": "completed"
157
+ },
158
+ "tags": []
159
+ },
160
+ "outputs": [
161
+ {
162
+ "name": "stderr",
163
+ "output_type": "stream",
164
+ "text": [
165
+ "/opt/conda/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.5\n",
166
+ " warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
167
+ ]
168
+ }
169
+ ],
170
+ "source": [
171
+ "import numpy as np\n",
172
+ "import pandas as pd\n",
173
+ "from sklearn.preprocessing import LabelEncoder,normalize\n",
174
+ "from sklearn.ensemble import GradientBoostingClassifier,RandomForestClassifier\n",
175
+ "from sklearn.metrics import accuracy_score\n",
176
+ "from sklearn.impute import SimpleImputer\n",
177
+ "import imblearn\n",
178
+ "from imblearn.over_sampling import RandomOverSampler\n",
179
+ "from imblearn.under_sampling import RandomUnderSampler\n",
180
+ "import xgboost\n",
181
+ "import inspect\n",
182
+ "from collections import defaultdict\n",
183
+ "from tabpfn import TabPFNClassifier\n",
184
+ "import lightgbm as lgb\n",
185
+ "import warnings\n",
186
+ "warnings.filterwarnings('ignore')\n",
187
+ "from sklearn.model_selection import KFold as KF, GridSearchCV\n",
188
+ "from tqdm.notebook import tqdm\n",
189
+ "from datetime import datetime"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": 5,
195
+ "id": "f0d7a8de",
196
+ "metadata": {
197
+ "execution": {
198
+ "iopub.execute_input": "2023-07-21T15:53:32.388083Z",
199
+ "iopub.status.busy": "2023-07-21T15:53:32.387679Z",
200
+ "iopub.status.idle": "2023-07-21T15:53:32.456602Z",
201
+ "shell.execute_reply": "2023-07-21T15:53:32.455535Z"
202
+ },
203
+ "papermill": {
204
+ "duration": 0.082612,
205
+ "end_time": "2023-07-21T15:53:32.459479",
206
+ "exception": false,
207
+ "start_time": "2023-07-21T15:53:32.376867",
208
+ "status": "completed"
209
+ },
210
+ "tags": []
211
+ },
212
+ "outputs": [],
213
+ "source": [
214
+ "\n",
215
+ "train_df = pd.read_csv(\"/kaggle/input/icr-identify-age-related-conditions/train.csv\")\n",
216
+ "test_df = pd.read_csv(\"/kaggle/input/icr-identify-age-related-conditions/test.csv\")\n",
217
+ "greeks_df = pd.read_csv(\"/kaggle/input/icr-identify-age-related-conditions/greeks.csv\")\n",
218
+ "sample_submission = pd.read_csv(\"/kaggle/input/icr-identify-age-related-conditions/sample_submission.csv\")"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": 6,
224
+ "id": "4f425c0c",
225
+ "metadata": {
226
+ "execution": {
227
+ "iopub.execute_input": "2023-07-21T15:53:32.480524Z",
228
+ "iopub.status.busy": "2023-07-21T15:53:32.479901Z",
229
+ "iopub.status.idle": "2023-07-21T15:53:32.496468Z",
230
+ "shell.execute_reply": "2023-07-21T15:53:32.495077Z"
231
+ },
232
+ "papermill": {
233
+ "duration": 0.029964,
234
+ "end_time": "2023-07-21T15:53:32.498999",
235
+ "exception": false,
236
+ "start_time": "2023-07-21T15:53:32.469035",
237
+ "status": "completed"
238
+ },
239
+ "tags": []
240
+ },
241
+ "outputs": [],
242
+ "source": [
243
+ "\n",
244
+ "f_c = train_df.EJ.unique()[0]\n",
245
+ "train_df.EJ = train_df.EJ.eq(f_c).astype('int')\n",
246
+ "test_df.EJ = test_df.EJ.eq(f_c).astype('int')"
247
+ ]
248
+ },
249
+ {
250
+ "cell_type": "code",
251
+ "execution_count": 7,
252
+ "id": "a2d05095",
253
+ "metadata": {
254
+ "execution": {
255
+ "iopub.execute_input": "2023-07-21T15:53:32.519923Z",
256
+ "iopub.status.busy": "2023-07-21T15:53:32.519160Z",
257
+ "iopub.status.idle": "2023-07-21T15:53:32.526632Z",
258
+ "shell.execute_reply": "2023-07-21T15:53:32.525280Z"
259
+ },
260
+ "papermill": {
261
+ "duration": 0.020616,
262
+ "end_time": "2023-07-21T15:53:32.529035",
263
+ "exception": false,
264
+ "start_time": "2023-07-21T15:53:32.508419",
265
+ "status": "completed"
266
+ },
267
+ "tags": []
268
+ },
269
+ "outputs": [],
270
+ "source": [
271
+ "\n",
272
+ "def random_under_sampler(df):\n",
273
+ " neg, pos = np.bincount(df['Class'])\n",
274
+ " one_df = df.loc[df['Class'] == 1]\n",
275
+ " zero_df = df.loc[df['Class'] == 0]\n",
276
+ " zero_df = zero_df.sample(n=pos)\n",
277
+ " undersampled_df = pd.concat([zero_df, one_df])\n",
278
+ " return undersampled_df.sample(frac = 1)\n"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": 8,
284
+ "id": "4dbfd096",
285
+ "metadata": {
286
+ "execution": {
287
+ "iopub.execute_input": "2023-07-21T15:53:32.550312Z",
288
+ "iopub.status.busy": "2023-07-21T15:53:32.549074Z",
289
+ "iopub.status.idle": "2023-07-21T15:53:32.563566Z",
290
+ "shell.execute_reply": "2023-07-21T15:53:32.562452Z"
291
+ },
292
+ "papermill": {
293
+ "duration": 0.028002,
294
+ "end_time": "2023-07-21T15:53:32.566324",
295
+ "exception": false,
296
+ "start_time": "2023-07-21T15:53:32.538322",
297
+ "status": "completed"
298
+ },
299
+ "tags": []
300
+ },
301
+ "outputs": [],
302
+ "source": [
303
+ "train_df_good = random_under_sampler(train_df)"
304
+ ]
305
+ },
306
+ {
307
+ "cell_type": "code",
308
+ "execution_count": 9,
309
+ "id": "6bd48549",
310
+ "metadata": {
311
+ "execution": {
312
+ "iopub.execute_input": "2023-07-21T15:53:32.586802Z",
313
+ "iopub.status.busy": "2023-07-21T15:53:32.586406Z",
314
+ "iopub.status.idle": "2023-07-21T15:53:32.594042Z",
315
+ "shell.execute_reply": "2023-07-21T15:53:32.592939Z"
316
+ },
317
+ "papermill": {
318
+ "duration": 0.021077,
319
+ "end_time": "2023-07-21T15:53:32.596735",
320
+ "exception": false,
321
+ "start_time": "2023-07-21T15:53:32.575658",
322
+ "status": "completed"
323
+ },
324
+ "tags": []
325
+ },
326
+ "outputs": [
327
+ {
328
+ "data": {
329
+ "text/plain": [
330
+ "(216, 58)"
331
+ ]
332
+ },
333
+ "execution_count": 9,
334
+ "metadata": {},
335
+ "output_type": "execute_result"
336
+ }
337
+ ],
338
+ "source": [
339
+ "\n",
340
+ "train_df_good.shape"
341
+ ]
342
+ },
343
+ {
344
+ "cell_type": "code",
345
+ "execution_count": 10,
346
+ "id": "fe4cad4c",
347
+ "metadata": {
348
+ "execution": {
349
+ "iopub.execute_input": "2023-07-21T15:53:32.617381Z",
350
+ "iopub.status.busy": "2023-07-21T15:53:32.616958Z",
351
+ "iopub.status.idle": "2023-07-21T15:53:32.626930Z",
352
+ "shell.execute_reply": "2023-07-21T15:53:32.625603Z"
353
+ },
354
+ "papermill": {
355
+ "duration": 0.023274,
356
+ "end_time": "2023-07-21T15:53:32.629359",
357
+ "exception": false,
358
+ "start_time": "2023-07-21T15:53:32.606085",
359
+ "status": "completed"
360
+ },
361
+ "tags": []
362
+ },
363
+ "outputs": [],
364
+ "source": [
365
+ "predictor_columns = [n for n in train_df.columns if n != 'Class' and n != 'Id']\n",
366
+ "x= train_df[predictor_columns]\n",
367
+ "y = train_df['Class']"
368
+ ]
369
+ },
370
+ {
371
+ "cell_type": "code",
372
+ "execution_count": 11,
373
+ "id": "dc4251ec",
374
+ "metadata": {
375
+ "execution": {
376
+ "iopub.execute_input": "2023-07-21T15:53:32.650279Z",
377
+ "iopub.status.busy": "2023-07-21T15:53:32.649843Z",
378
+ "iopub.status.idle": "2023-07-21T15:53:32.654686Z",
379
+ "shell.execute_reply": "2023-07-21T15:53:32.653811Z"
380
+ },
381
+ "papermill": {
382
+ "duration": 0.017862,
383
+ "end_time": "2023-07-21T15:53:32.656879",
384
+ "exception": false,
385
+ "start_time": "2023-07-21T15:53:32.639017",
386
+ "status": "completed"
387
+ },
388
+ "tags": []
389
+ },
390
+ "outputs": [],
391
+ "source": [
392
+ "cv_outer = KF(n_splits = 10, shuffle=True, random_state=42)\n",
393
+ "cv_inner = KF(n_splits = 5, shuffle=True, random_state=42)"
394
+ ]
395
+ },
396
+ {
397
+ "cell_type": "code",
398
+ "execution_count": 12,
399
+ "id": "55f481b0",
400
+ "metadata": {
401
+ "execution": {
402
+ "iopub.execute_input": "2023-07-21T15:53:32.677484Z",
403
+ "iopub.status.busy": "2023-07-21T15:53:32.677040Z",
404
+ "iopub.status.idle": "2023-07-21T15:53:32.684784Z",
405
+ "shell.execute_reply": "2023-07-21T15:53:32.683665Z"
406
+ },
407
+ "papermill": {
408
+ "duration": 0.020651,
409
+ "end_time": "2023-07-21T15:53:32.687050",
410
+ "exception": false,
411
+ "start_time": "2023-07-21T15:53:32.666399",
412
+ "status": "completed"
413
+ },
414
+ "tags": []
415
+ },
416
+ "outputs": [],
417
+ "source": [
418
+ "def balanced_log_loss(y_true, y_pred):\n",
419
+ " N_0 = np.sum(1 - y_true)\n",
420
+ " N_1 = np.sum(y_true)\n",
421
+ " \n",
422
+ " w_0 = 1 / N_0\n",
423
+ " w_1 = 1 / N_1\n",
424
+ " p_1 = np.clip(y_pred, 1e-15, 1 - 1e-15)\n",
425
+ " p_0 = 1 - p_1\n",
426
+ " log_loss_0 = -np.sum((1 - y_true) * np.log(p_0))\n",
427
+ " log_loss_1 = -np.sum(y_true * np.log(p_1))\n",
428
+ " balanced_log_loss = 2*(w_0 * log_loss_0 + w_1 * log_loss_1) / (w_0 + w_1)\n",
429
+ " return balanced_log_loss/(N_0+N_1)"
430
+ ]
431
+ },
432
+ {
433
+ "cell_type": "code",
434
+ "execution_count": 13,
435
+ "id": "8a4e9d35",
436
+ "metadata": {
437
+ "execution": {
438
+ "iopub.execute_input": "2023-07-21T15:53:32.707920Z",
439
+ "iopub.status.busy": "2023-07-21T15:53:32.707502Z",
440
+ "iopub.status.idle": "2023-07-21T15:53:32.721521Z",
441
+ "shell.execute_reply": "2023-07-21T15:53:32.720277Z"
442
+ },
443
+ "papermill": {
444
+ "duration": 0.027412,
445
+ "end_time": "2023-07-21T15:53:32.723937",
446
+ "exception": false,
447
+ "start_time": "2023-07-21T15:53:32.696525",
448
+ "status": "completed"
449
+ },
450
+ "tags": []
451
+ },
452
+ "outputs": [],
453
+ "source": [
454
+ "class Ensemble():\n",
455
+ " def __init__(self):\n",
456
+ " self.imputer = SimpleImputer(missing_values=np.nan, strategy='median')\n",
457
+ " self.classifiers =[xgboost.XGBClassifier(n_estimators=100,max_depth=3,learning_rate=0.2,subsample=0.9,colsample_bytree=0.85),\n",
458
+ " xgboost.XGBClassifier(),\n",
459
+ " TabPFNClassifier(N_ensemble_configurations=128),\n",
460
+ " TabPFNClassifier(N_ensemble_configurations=48)]\n",
461
+ " \n",
462
+ " def fit(self,X,y):\n",
463
+ " y = y.values\n",
464
+ " unique_classes, y = np.unique(y, return_inverse=True)\n",
465
+ " self.classes_ = unique_classes\n",
466
+ " first_category = X.EJ.unique()[0]\n",
467
+ " X.EJ = X.EJ.eq(first_category).astype('int')\n",
468
+ " X = self.imputer.fit_transform(X)\n",
469
+ "\n",
470
+ " for classifier in self.classifiers:\n",
471
+ " if classifier==self.classifiers[2] or classifier==self.classifiers[3]:\n",
472
+ " classifier.fit(X,y,overwrite_warning =True)\n",
473
+ " else :\n",
474
+ " classifier.fit(X, y)\n",
475
+ " \n",
476
+ " def predict_proba(self, x):\n",
477
+ " x = self.imputer.transform(x)\n",
478
+ " probabilities = np.stack([classifier.predict_proba(x) for classifier in self.classifiers])\n",
479
+ " averaged_probabilities = np.mean(probabilities, axis=0)\n",
480
+ " class_0_est_instances = averaged_probabilities[:, 0].sum()\n",
481
+ " others_est_instances = averaged_probabilities[:, 1:].sum()\n",
482
+ " # Weighted probabilities based on class imbalance\n",
483
+ " new_probabilities = averaged_probabilities * np.array([[1/(class_0_est_instances if i==0 else others_est_instances) for i in range(averaged_probabilities.shape[1])]])\n",
484
+ " return new_probabilities / np.sum(new_probabilities, axis=1, keepdims=1) "
485
+ ]
486
+ },
487
+ {
488
+ "cell_type": "code",
489
+ "execution_count": 14,
490
+ "id": "2d8e9bac",
491
+ "metadata": {
492
+ "execution": {
493
+ "iopub.execute_input": "2023-07-21T15:53:32.745072Z",
494
+ "iopub.status.busy": "2023-07-21T15:53:32.744663Z",
495
+ "iopub.status.idle": "2023-07-21T15:53:32.757423Z",
496
+ "shell.execute_reply": "2023-07-21T15:53:32.756143Z"
497
+ },
498
+ "papermill": {
499
+ "duration": 0.026546,
500
+ "end_time": "2023-07-21T15:53:32.760075",
501
+ "exception": false,
502
+ "start_time": "2023-07-21T15:53:32.733529",
503
+ "status": "completed"
504
+ },
505
+ "tags": []
506
+ },
507
+ "outputs": [],
508
+ "source": [
509
+ "def training(model, x,y,y_meta):\n",
510
+ " outer_results = list()\n",
511
+ " best_loss = np.inf\n",
512
+ " split = 0\n",
513
+ " splits = 5\n",
514
+ " for train_idx,val_idx in tqdm(cv_inner.split(x), total = splits):\n",
515
+ " split+=1\n",
516
+ " x_train, x_val = x.iloc[train_idx],x.iloc[val_idx]\n",
517
+ " y_train, y_val = y_meta.iloc[train_idx], y.iloc[val_idx]\n",
518
+ " \n",
519
+ " model.fit(x_train, y_train)\n",
520
+ " y_pred = model.predict_proba(x_val)\n",
521
+ " probabilities = np.concatenate((y_pred[:,:1], np.sum(y_pred[:,1:], 1, keepdims=True)), axis=1)\n",
522
+ " p0 = probabilities[:,:1]\n",
523
+ " p0[p0 > 0.86] = 1\n",
524
+ " p0[p0 < 0.14] = 0\n",
525
+ " y_p = np.empty((y_pred.shape[0],))\n",
526
+ " for i in range(y_pred.shape[0]):\n",
527
+ " if p0[i]>=0.5:\n",
528
+ " y_p[i]= False\n",
529
+ " else :\n",
530
+ " y_p[i]=True\n",
531
+ " y_p = y_p.astype(int)\n",
532
+ " loss = balanced_log_loss(y_val,y_p)\n",
533
+ "\n",
534
+ " if loss<best_loss:\n",
535
+ " best_model = model\n",
536
+ " best_loss = loss\n",
537
+ " print('best_model_saved')\n",
538
+ " outer_results.append(loss)\n",
539
+ " print('>val_loss=%.5f, split = %.1f' % (loss,split))\n",
540
+ " print('LOSS: %.5f' % (np.mean(outer_results)))\n",
541
+ " return best_model"
542
+ ]
543
+ },
544
+ {
545
+ "cell_type": "code",
546
+ "execution_count": 15,
547
+ "id": "eac85a04",
548
+ "metadata": {
549
+ "execution": {
550
+ "iopub.execute_input": "2023-07-21T15:53:32.780758Z",
551
+ "iopub.status.busy": "2023-07-21T15:53:32.780354Z",
552
+ "iopub.status.idle": "2023-07-21T15:53:32.804335Z",
553
+ "shell.execute_reply": "2023-07-21T15:53:32.802984Z"
554
+ },
555
+ "papermill": {
556
+ "duration": 0.03723,
557
+ "end_time": "2023-07-21T15:53:32.806892",
558
+ "exception": false,
559
+ "start_time": "2023-07-21T15:53:32.769662",
560
+ "status": "completed"
561
+ },
562
+ "tags": []
563
+ },
564
+ "outputs": [],
565
+ "source": [
566
+ "times = greeks_df.Epsilon.copy()\n",
567
+ "times[greeks_df.Epsilon != 'Unknown'] = greeks_df.Epsilon[greeks_df.Epsilon != 'Unknown'].map(lambda x: datetime.strptime(x,'%m/%d/%Y').toordinal())\n",
568
+ "times[greeks_df.Epsilon == 'Unknown'] = np.nan"
569
+ ]
570
+ },
571
+ {
572
+ "cell_type": "code",
573
+ "execution_count": 16,
574
+ "id": "13f7db8f",
575
+ "metadata": {
576
+ "execution": {
577
+ "iopub.execute_input": "2023-07-21T15:53:32.827820Z",
578
+ "iopub.status.busy": "2023-07-21T15:53:32.827408Z",
579
+ "iopub.status.idle": "2023-07-21T15:53:32.842949Z",
580
+ "shell.execute_reply": "2023-07-21T15:53:32.841736Z"
581
+ },
582
+ "papermill": {
583
+ "duration": 0.029024,
584
+ "end_time": "2023-07-21T15:53:32.845462",
585
+ "exception": false,
586
+ "start_time": "2023-07-21T15:53:32.816438",
587
+ "status": "completed"
588
+ },
589
+ "tags": []
590
+ },
591
+ "outputs": [],
592
+ "source": [
593
+ "train_pred_and_time = pd.concat((train_df, times), axis=1)\n",
594
+ "test_predictors = test_df[predictor_columns]\n",
595
+ "f_c = test_predictors.EJ.unique()[0]\n",
596
+ "test_predictors.EJ = test_predictors.EJ.eq(f_c).astype('int')\n",
597
+ "test_pred_and_time = np.concatenate((test_predictors, np.zeros((len(test_predictors), 1)) + train_pred_and_time.Epsilon.max() + 1), axis=1)"
598
+ ]
599
+ },
600
+ {
601
+ "cell_type": "code",
602
+ "execution_count": 17,
603
+ "id": "7c59e8fa",
604
+ "metadata": {
605
+ "execution": {
606
+ "iopub.execute_input": "2023-07-21T15:53:32.866481Z",
607
+ "iopub.status.busy": "2023-07-21T15:53:32.866028Z",
608
+ "iopub.status.idle": "2023-07-21T15:53:32.909710Z",
609
+ "shell.execute_reply": "2023-07-21T15:53:32.908189Z"
610
+ },
611
+ "papermill": {
612
+ "duration": 0.057234,
613
+ "end_time": "2023-07-21T15:53:32.912386",
614
+ "exception": false,
615
+ "start_time": "2023-07-21T15:53:32.855152",
616
+ "status": "completed"
617
+ },
618
+ "tags": []
619
+ },
620
+ "outputs": [
621
+ {
622
+ "name": "stdout",
623
+ "output_type": "stream",
624
+ "text": [
625
+ "Original dataset shape\n",
626
+ "A 509\n",
627
+ "B 61\n",
628
+ "G 29\n",
629
+ "D 18\n",
630
+ "Name: Alpha, dtype: int64\n",
631
+ "Resample dataset shape\n",
632
+ "B 509\n",
633
+ "A 509\n",
634
+ "D 509\n",
635
+ "G 509\n",
636
+ "Name: Alpha, dtype: int64\n"
637
+ ]
638
+ }
639
+ ],
640
+ "source": [
641
+ "ros = RandomOverSampler(random_state=42)\n",
642
+ "train_ros, y_ros = ros.fit_resample(train_pred_and_time, greeks_df.Alpha)\n",
643
+ "print('Original dataset shape')\n",
644
+ "print(greeks_df.Alpha.value_counts())\n",
645
+ "print('Resample dataset shape')\n",
646
+ "print( y_ros.value_counts())"
647
+ ]
648
+ },
649
+ {
650
+ "cell_type": "code",
651
+ "execution_count": 18,
652
+ "id": "24f93672",
653
+ "metadata": {
654
+ "execution": {
655
+ "iopub.execute_input": "2023-07-21T15:53:32.934654Z",
656
+ "iopub.status.busy": "2023-07-21T15:53:32.934203Z",
657
+ "iopub.status.idle": "2023-07-21T15:53:32.943106Z",
658
+ "shell.execute_reply": "2023-07-21T15:53:32.941927Z"
659
+ },
660
+ "papermill": {
661
+ "duration": 0.023743,
662
+ "end_time": "2023-07-21T15:53:32.945785",
663
+ "exception": false,
664
+ "start_time": "2023-07-21T15:53:32.922042",
665
+ "status": "completed"
666
+ },
667
+ "tags": []
668
+ },
669
+ "outputs": [],
670
+ "source": [
671
+ "\n",
672
+ "x_ros = train_ros.drop(['Class', 'Id'],axis=1)\n",
673
+ "y_ = train_ros.Class"
674
+ ]
675
+ },
676
+ {
677
+ "cell_type": "code",
678
+ "execution_count": 19,
679
+ "id": "3bb22cb9",
680
+ "metadata": {
681
+ "execution": {
682
+ "iopub.execute_input": "2023-07-21T15:53:32.969825Z",
683
+ "iopub.status.busy": "2023-07-21T15:53:32.969410Z",
684
+ "iopub.status.idle": "2023-07-21T15:53:34.012308Z",
685
+ "shell.execute_reply": "2023-07-21T15:53:34.011458Z"
686
+ },
687
+ "papermill": {
688
+ "duration": 1.058482,
689
+ "end_time": "2023-07-21T15:53:34.015114",
690
+ "exception": false,
691
+ "start_time": "2023-07-21T15:53:32.956632",
692
+ "status": "completed"
693
+ },
694
+ "tags": []
695
+ },
696
+ "outputs": [
697
+ {
698
+ "name": "stdout",
699
+ "output_type": "stream",
700
+ "text": [
701
+ "Loading model that can be used for inference only\n",
702
+ "Using a Transformer with 25.82 M parameters\n",
703
+ "Loading model that can be used for inference only\n",
704
+ "Using a Transformer with 25.82 M parameters\n"
705
+ ]
706
+ }
707
+ ],
708
+ "source": [
709
+ "\n",
710
+ "yt = Ensemble()"
711
+ ]
712
+ },
713
+ {
714
+ "cell_type": "code",
715
+ "execution_count": 20,
716
+ "id": "2f4be2ee",
717
+ "metadata": {
718
+ "execution": {
719
+ "iopub.execute_input": "2023-07-21T15:53:34.038390Z",
720
+ "iopub.status.busy": "2023-07-21T15:53:34.037910Z",
721
+ "iopub.status.idle": "2023-07-21T16:36:53.086642Z",
722
+ "shell.execute_reply": "2023-07-21T16:36:53.085438Z"
723
+ },
724
+ "papermill": {
725
+ "duration": 2599.074078,
726
+ "end_time": "2023-07-21T16:36:53.100516",
727
+ "exception": false,
728
+ "start_time": "2023-07-21T15:53:34.026438",
729
+ "status": "completed"
730
+ },
731
+ "tags": []
732
+ },
733
+ "outputs": [
734
+ {
735
+ "data": {
736
+ "application/vnd.jupyter.widget-view+json": {
737
+ "model_id": "1486711fdccc430bb8f19ffe0003cdf5",
738
+ "version_major": 2,
739
+ "version_minor": 0
740
+ },
741
+ "text/plain": [
742
+ " 0%| | 0/5 [00:00<?, ?it/s]"
743
+ ]
744
+ },
745
+ "metadata": {},
746
+ "output_type": "display_data"
747
+ },
748
+ {
749
+ "name": "stdout",
750
+ "output_type": "stream",
751
+ "text": [
752
+ "best_model_saved\n",
753
+ ">val_loss=0.12283, split = 1.0\n"
754
+ ]
755
+ },
756
+ {
757
+ "name": "stderr",
758
+ "output_type": "stream",
759
+ "text": [
760
+ "/tmp/ipykernel_20/772101332.py:14: SettingWithCopyWarning: \n",
761
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
762
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
763
+ "\n",
764
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
765
+ " X.EJ = X.EJ.eq(first_category).astype('int')\n"
766
+ ]
767
+ },
768
+ {
769
+ "name": "stdout",
770
+ "output_type": "stream",
771
+ "text": [
772
+ "best_model_saved\n",
773
+ ">val_loss=0.00000, split = 2.0\n"
774
+ ]
775
+ },
776
+ {
777
+ "name": "stderr",
778
+ "output_type": "stream",
779
+ "text": [
780
+ "/tmp/ipykernel_20/772101332.py:14: SettingWithCopyWarning: \n",
781
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
782
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
783
+ "\n",
784
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
785
+ " X.EJ = X.EJ.eq(first_category).astype('int')\n"
786
+ ]
787
+ },
788
+ {
789
+ "name": "stdout",
790
+ "output_type": "stream",
791
+ "text": [
792
+ ">val_loss=0.00000, split = 3.0\n"
793
+ ]
794
+ },
795
+ {
796
+ "name": "stderr",
797
+ "output_type": "stream",
798
+ "text": [
799
+ "/tmp/ipykernel_20/772101332.py:14: SettingWithCopyWarning: \n",
800
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
801
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
802
+ "\n",
803
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
804
+ " X.EJ = X.EJ.eq(first_category).astype('int')\n"
805
+ ]
806
+ },
807
+ {
808
+ "name": "stdout",
809
+ "output_type": "stream",
810
+ "text": [
811
+ ">val_loss=0.13220, split = 4.0\n"
812
+ ]
813
+ },
814
+ {
815
+ "name": "stderr",
816
+ "output_type": "stream",
817
+ "text": [
818
+ "/tmp/ipykernel_20/772101332.py:14: SettingWithCopyWarning: \n",
819
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
820
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
821
+ "\n",
822
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
823
+ " X.EJ = X.EJ.eq(first_category).astype('int')\n"
824
+ ]
825
+ },
826
+ {
827
+ "name": "stdout",
828
+ "output_type": "stream",
829
+ "text": [
830
+ ">val_loss=0.13386, split = 5.0\n",
831
+ "LOSS: 0.07778\n"
832
+ ]
833
+ }
834
+ ],
835
+ "source": [
836
+ "\n",
837
+ "m = training(yt,x_ros,y_,y_ros)"
838
+ ]
839
+ },
840
+ {
841
+ "cell_type": "code",
842
+ "execution_count": 21,
843
+ "id": "d96b306f",
844
+ "metadata": {
845
+ "execution": {
846
+ "iopub.execute_input": "2023-07-21T16:36:53.124238Z",
847
+ "iopub.status.busy": "2023-07-21T16:36:53.123776Z",
848
+ "iopub.status.idle": "2023-07-21T16:36:53.134302Z",
849
+ "shell.execute_reply": "2023-07-21T16:36:53.133149Z"
850
+ },
851
+ "papermill": {
852
+ "duration": 0.025325,
853
+ "end_time": "2023-07-21T16:36:53.136690",
854
+ "exception": false,
855
+ "start_time": "2023-07-21T16:36:53.111365",
856
+ "status": "completed"
857
+ },
858
+ "tags": []
859
+ },
860
+ "outputs": [
861
+ {
862
+ "data": {
863
+ "text/plain": [
864
+ "1 0.75\n",
865
+ "0 0.25\n",
866
+ "Name: Class, dtype: float64"
867
+ ]
868
+ },
869
+ "execution_count": 21,
870
+ "metadata": {},
871
+ "output_type": "execute_result"
872
+ }
873
+ ],
874
+ "source": [
875
+ "y_.value_counts()/y_.shape[0]"
876
+ ]
877
+ },
878
+ {
879
+ "cell_type": "code",
880
+ "execution_count": 22,
881
+ "id": "c45bd72b",
882
+ "metadata": {
883
+ "execution": {
884
+ "iopub.execute_input": "2023-07-21T16:36:53.161504Z",
885
+ "iopub.status.busy": "2023-07-21T16:36:53.160656Z",
886
+ "iopub.status.idle": "2023-07-21T16:44:01.043865Z",
887
+ "shell.execute_reply": "2023-07-21T16:44:01.042239Z"
888
+ },
889
+ "papermill": {
890
+ "duration": 427.900911,
891
+ "end_time": "2023-07-21T16:44:01.048717",
892
+ "exception": false,
893
+ "start_time": "2023-07-21T16:36:53.147806",
894
+ "status": "completed"
895
+ },
896
+ "tags": []
897
+ },
898
+ "outputs": [
899
+ {
900
+ "name": "stderr",
901
+ "output_type": "stream",
902
+ "text": [
903
+ "/opt/conda/lib/python3.10/site-packages/sklearn/base.py:439: UserWarning: X does not have valid feature names, but SimpleImputer was fitted with feature names\n",
904
+ " warnings.warn(\n"
905
+ ]
906
+ }
907
+ ],
908
+ "source": [
909
+ "\n",
910
+ "y_pred = m.predict_proba(test_pred_and_time)\n",
911
+ "probabilities = np.concatenate((y_pred[:,:1], np.sum(y_pred[:,1:], 1, keepdims=True)), axis=1)\n",
912
+ "p0 = probabilities[:,:1]\n",
913
+ "p0[p0 > 0.59] = 1\n",
914
+ "p0[p0 < 0.28] = 0"
915
+ ]
916
+ },
917
+ {
918
+ "cell_type": "code",
919
+ "execution_count": 23,
920
+ "id": "967ed14e",
921
+ "metadata": {
922
+ "execution": {
923
+ "iopub.execute_input": "2023-07-21T16:44:01.076361Z",
924
+ "iopub.status.busy": "2023-07-21T16:44:01.075908Z",
925
+ "iopub.status.idle": "2023-07-21T16:44:01.090619Z",
926
+ "shell.execute_reply": "2023-07-21T16:44:01.089411Z"
927
+ },
928
+ "papermill": {
929
+ "duration": 0.030936,
930
+ "end_time": "2023-07-21T16:44:01.093502",
931
+ "exception": false,
932
+ "start_time": "2023-07-21T16:44:01.062566",
933
+ "status": "completed"
934
+ },
935
+ "tags": []
936
+ },
937
+ "outputs": [],
938
+ "source": [
939
+ "\n",
940
+ "submission = pd.DataFrame(test_df[\"Id\"], columns=[\"Id\"])\n",
941
+ "submission[\"class_0\"] = p0\n",
942
+ "submission[\"class_1\"] = 1 - p0\n",
943
+ "submission.to_csv('submission.csv', index=False)"
944
+ ]
945
+ },
946
+ {
947
+ "cell_type": "code",
948
+ "execution_count": 24,
949
+ "id": "f9b4a7c5",
950
+ "metadata": {
951
+ "execution": {
952
+ "iopub.execute_input": "2023-07-21T16:44:01.118941Z",
953
+ "iopub.status.busy": "2023-07-21T16:44:01.118283Z",
954
+ "iopub.status.idle": "2023-07-21T16:44:01.136373Z",
955
+ "shell.execute_reply": "2023-07-21T16:44:01.135160Z"
956
+ },
957
+ "papermill": {
958
+ "duration": 0.03388,
959
+ "end_time": "2023-07-21T16:44:01.139087",
960
+ "exception": false,
961
+ "start_time": "2023-07-21T16:44:01.105207",
962
+ "status": "completed"
963
+ },
964
+ "tags": []
965
+ },
966
+ "outputs": [
967
+ {
968
+ "data": {
969
+ "text/html": [
970
+ "<div>\n",
971
+ "<style scoped>\n",
972
+ " .dataframe tbody tr th:only-of-type {\n",
973
+ " vertical-align: middle;\n",
974
+ " }\n",
975
+ "\n",
976
+ " .dataframe tbody tr th {\n",
977
+ " vertical-align: top;\n",
978
+ " }\n",
979
+ "\n",
980
+ " .dataframe thead th {\n",
981
+ " text-align: right;\n",
982
+ " }\n",
983
+ "</style>\n",
984
+ "<table border=\"1\" class=\"dataframe\">\n",
985
+ " <thead>\n",
986
+ " <tr style=\"text-align: right;\">\n",
987
+ " <th></th>\n",
988
+ " <th>Id</th>\n",
989
+ " <th>class_0</th>\n",
990
+ " <th>class_1</th>\n",
991
+ " </tr>\n",
992
+ " </thead>\n",
993
+ " <tbody>\n",
994
+ " <tr>\n",
995
+ " <th>0</th>\n",
996
+ " <td>00eed32682bb</td>\n",
997
+ " <td>0.5</td>\n",
998
+ " <td>0.5</td>\n",
999
+ " </tr>\n",
1000
+ " <tr>\n",
1001
+ " <th>1</th>\n",
1002
+ " <td>010ebe33f668</td>\n",
1003
+ " <td>0.5</td>\n",
1004
+ " <td>0.5</td>\n",
1005
+ " </tr>\n",
1006
+ " <tr>\n",
1007
+ " <th>2</th>\n",
1008
+ " <td>02fa521e1838</td>\n",
1009
+ " <td>0.5</td>\n",
1010
+ " <td>0.5</td>\n",
1011
+ " </tr>\n",
1012
+ " <tr>\n",
1013
+ " <th>3</th>\n",
1014
+ " <td>040e15f562a2</td>\n",
1015
+ " <td>0.5</td>\n",
1016
+ " <td>0.5</td>\n",
1017
+ " </tr>\n",
1018
+ " <tr>\n",
1019
+ " <th>4</th>\n",
1020
+ " <td>046e85c7cc7f</td>\n",
1021
+ " <td>0.5</td>\n",
1022
+ " <td>0.5</td>\n",
1023
+ " </tr>\n",
1024
+ " </tbody>\n",
1025
+ "</table>\n",
1026
+ "</div>"
1027
+ ],
1028
+ "text/plain": [
1029
+ " Id class_0 class_1\n",
1030
+ "0 00eed32682bb 0.5 0.5\n",
1031
+ "1 010ebe33f668 0.5 0.5\n",
1032
+ "2 02fa521e1838 0.5 0.5\n",
1033
+ "3 040e15f562a2 0.5 0.5\n",
1034
+ "4 046e85c7cc7f 0.5 0.5"
1035
+ ]
1036
+ },
1037
+ "execution_count": 24,
1038
+ "metadata": {},
1039
+ "output_type": "execute_result"
1040
+ }
1041
+ ],
1042
+ "source": [
1043
+ "submission_df = pd.read_csv('submission.csv')\n",
1044
+ "submission_df"
1045
+ ]
1046
+ }
1047
+ ],
1048
+ "metadata": {
1049
+ "kernelspec": {
1050
+ "display_name": "Python 3",
1051
+ "language": "python",
1052
+ "name": "python3"
1053
+ },
1054
+ "language_info": {
1055
+ "codemirror_mode": {
1056
+ "name": "ipython",
1057
+ "version": 3
1058
+ },
1059
+ "file_extension": ".py",
1060
+ "mimetype": "text/x-python",
1061
+ "name": "python",
1062
+ "nbconvert_exporter": "python",
1063
+ "pygments_lexer": "ipython3",
1064
+ "version": "3.10.12"
1065
+ },
1066
+ "papermill": {
1067
+ "default_parameters": {},
1068
+ "duration": 3089.637665,
1069
+ "end_time": "2023-07-21T16:44:03.814098",
1070
+ "environment_variables": {},
1071
+ "exception": null,
1072
+ "input_path": "__notebook__.ipynb",
1073
+ "output_path": "__notebook__.ipynb",
1074
+ "parameters": {},
1075
+ "start_time": "2023-07-21T15:52:34.176433",
1076
+ "version": "2.4.0"
1077
+ },
1078
+ "widgets": {
1079
+ "application/vnd.jupyter.widget-state+json": {
1080
+ "state": {
1081
+ "0eb183199e764348bf581e2ffce05ecf": {
1082
+ "model_module": "@jupyter-widgets/base",
1083
+ "model_module_version": "1.2.0",
1084
+ "model_name": "LayoutModel",
1085
+ "state": {
1086
+ "_model_module": "@jupyter-widgets/base",
1087
+ "_model_module_version": "1.2.0",
1088
+ "_model_name": "LayoutModel",
1089
+ "_view_count": null,
1090
+ "_view_module": "@jupyter-widgets/base",
1091
+ "_view_module_version": "1.2.0",
1092
+ "_view_name": "LayoutView",
1093
+ "align_content": null,
1094
+ "align_items": null,
1095
+ "align_self": null,
1096
+ "border": null,
1097
+ "bottom": null,
1098
+ "display": null,
1099
+ "flex": null,
1100
+ "flex_flow": null,
1101
+ "grid_area": null,
1102
+ "grid_auto_columns": null,
1103
+ "grid_auto_flow": null,
1104
+ "grid_auto_rows": null,
1105
+ "grid_column": null,
1106
+ "grid_gap": null,
1107
+ "grid_row": null,
1108
+ "grid_template_areas": null,
1109
+ "grid_template_columns": null,
1110
+ "grid_template_rows": null,
1111
+ "height": null,
1112
+ "justify_content": null,
1113
+ "justify_items": null,
1114
+ "left": null,
1115
+ "margin": null,
1116
+ "max_height": null,
1117
+ "max_width": null,
1118
+ "min_height": null,
1119
+ "min_width": null,
1120
+ "object_fit": null,
1121
+ "object_position": null,
1122
+ "order": null,
1123
+ "overflow": null,
1124
+ "overflow_x": null,
1125
+ "overflow_y": null,
1126
+ "padding": null,
1127
+ "right": null,
1128
+ "top": null,
1129
+ "visibility": null,
1130
+ "width": null
1131
+ }
1132
+ },
1133
+ "1486711fdccc430bb8f19ffe0003cdf5": {
1134
+ "model_module": "@jupyter-widgets/controls",
1135
+ "model_module_version": "1.5.0",
1136
+ "model_name": "HBoxModel",
1137
+ "state": {
1138
+ "_dom_classes": [],
1139
+ "_model_module": "@jupyter-widgets/controls",
1140
+ "_model_module_version": "1.5.0",
1141
+ "_model_name": "HBoxModel",
1142
+ "_view_count": null,
1143
+ "_view_module": "@jupyter-widgets/controls",
1144
+ "_view_module_version": "1.5.0",
1145
+ "_view_name": "HBoxView",
1146
+ "box_style": "",
1147
+ "children": [
1148
+ "IPY_MODEL_f0a7e9ffa8394babae8a609317de1970",
1149
+ "IPY_MODEL_d75596a40df548e8a3c16233b0c56d18",
1150
+ "IPY_MODEL_ee353ab052bc474ead46fd0f9ed9e203"
1151
+ ],
1152
+ "layout": "IPY_MODEL_b2f3a43388b041edba2a3f9834a28a94"
1153
+ }
1154
+ },
1155
+ "27e595e199224d0d93832f5dd41379e8": {
1156
+ "model_module": "@jupyter-widgets/controls",
1157
+ "model_module_version": "1.5.0",
1158
+ "model_name": "DescriptionStyleModel",
1159
+ "state": {
1160
+ "_model_module": "@jupyter-widgets/controls",
1161
+ "_model_module_version": "1.5.0",
1162
+ "_model_name": "DescriptionStyleModel",
1163
+ "_view_count": null,
1164
+ "_view_module": "@jupyter-widgets/base",
1165
+ "_view_module_version": "1.2.0",
1166
+ "_view_name": "StyleView",
1167
+ "description_width": ""
1168
+ }
1169
+ },
1170
+ "44e788a6798c486db6f6a9d4d910eb9e": {
1171
+ "model_module": "@jupyter-widgets/base",
1172
+ "model_module_version": "1.2.0",
1173
+ "model_name": "LayoutModel",
1174
+ "state": {
1175
+ "_model_module": "@jupyter-widgets/base",
1176
+ "_model_module_version": "1.2.0",
1177
+ "_model_name": "LayoutModel",
1178
+ "_view_count": null,
1179
+ "_view_module": "@jupyter-widgets/base",
1180
+ "_view_module_version": "1.2.0",
1181
+ "_view_name": "LayoutView",
1182
+ "align_content": null,
1183
+ "align_items": null,
1184
+ "align_self": null,
1185
+ "border": null,
1186
+ "bottom": null,
1187
+ "display": null,
1188
+ "flex": null,
1189
+ "flex_flow": null,
1190
+ "grid_area": null,
1191
+ "grid_auto_columns": null,
1192
+ "grid_auto_flow": null,
1193
+ "grid_auto_rows": null,
1194
+ "grid_column": null,
1195
+ "grid_gap": null,
1196
+ "grid_row": null,
1197
+ "grid_template_areas": null,
1198
+ "grid_template_columns": null,
1199
+ "grid_template_rows": null,
1200
+ "height": null,
1201
+ "justify_content": null,
1202
+ "justify_items": null,
1203
+ "left": null,
1204
+ "margin": null,
1205
+ "max_height": null,
1206
+ "max_width": null,
1207
+ "min_height": null,
1208
+ "min_width": null,
1209
+ "object_fit": null,
1210
+ "object_position": null,
1211
+ "order": null,
1212
+ "overflow": null,
1213
+ "overflow_x": null,
1214
+ "overflow_y": null,
1215
+ "padding": null,
1216
+ "right": null,
1217
+ "top": null,
1218
+ "visibility": null,
1219
+ "width": null
1220
+ }
1221
+ },
1222
+ "ae831dbd067b467ab82639cf32c7c94d": {
1223
+ "model_module": "@jupyter-widgets/controls",
1224
+ "model_module_version": "1.5.0",
1225
+ "model_name": "ProgressStyleModel",
1226
+ "state": {
1227
+ "_model_module": "@jupyter-widgets/controls",
1228
+ "_model_module_version": "1.5.0",
1229
+ "_model_name": "ProgressStyleModel",
1230
+ "_view_count": null,
1231
+ "_view_module": "@jupyter-widgets/base",
1232
+ "_view_module_version": "1.2.0",
1233
+ "_view_name": "StyleView",
1234
+ "bar_color": null,
1235
+ "description_width": ""
1236
+ }
1237
+ },
1238
+ "b2f3a43388b041edba2a3f9834a28a94": {
1239
+ "model_module": "@jupyter-widgets/base",
1240
+ "model_module_version": "1.2.0",
1241
+ "model_name": "LayoutModel",
1242
+ "state": {
1243
+ "_model_module": "@jupyter-widgets/base",
1244
+ "_model_module_version": "1.2.0",
1245
+ "_model_name": "LayoutModel",
1246
+ "_view_count": null,
1247
+ "_view_module": "@jupyter-widgets/base",
1248
+ "_view_module_version": "1.2.0",
1249
+ "_view_name": "LayoutView",
1250
+ "align_content": null,
1251
+ "align_items": null,
1252
+ "align_self": null,
1253
+ "border": null,
1254
+ "bottom": null,
1255
+ "display": null,
1256
+ "flex": null,
1257
+ "flex_flow": null,
1258
+ "grid_area": null,
1259
+ "grid_auto_columns": null,
1260
+ "grid_auto_flow": null,
1261
+ "grid_auto_rows": null,
1262
+ "grid_column": null,
1263
+ "grid_gap": null,
1264
+ "grid_row": null,
1265
+ "grid_template_areas": null,
1266
+ "grid_template_columns": null,
1267
+ "grid_template_rows": null,
1268
+ "height": null,
1269
+ "justify_content": null,
1270
+ "justify_items": null,
1271
+ "left": null,
1272
+ "margin": null,
1273
+ "max_height": null,
1274
+ "max_width": null,
1275
+ "min_height": null,
1276
+ "min_width": null,
1277
+ "object_fit": null,
1278
+ "object_position": null,
1279
+ "order": null,
1280
+ "overflow": null,
1281
+ "overflow_x": null,
1282
+ "overflow_y": null,
1283
+ "padding": null,
1284
+ "right": null,
1285
+ "top": null,
1286
+ "visibility": null,
1287
+ "width": null
1288
+ }
1289
+ },
1290
+ "bfd66578853d48eeb16725dccf2b9065": {
1291
+ "model_module": "@jupyter-widgets/controls",
1292
+ "model_module_version": "1.5.0",
1293
+ "model_name": "DescriptionStyleModel",
1294
+ "state": {
1295
+ "_model_module": "@jupyter-widgets/controls",
1296
+ "_model_module_version": "1.5.0",
1297
+ "_model_name": "DescriptionStyleModel",
1298
+ "_view_count": null,
1299
+ "_view_module": "@jupyter-widgets/base",
1300
+ "_view_module_version": "1.2.0",
1301
+ "_view_name": "StyleView",
1302
+ "description_width": ""
1303
+ }
1304
+ },
1305
+ "d75596a40df548e8a3c16233b0c56d18": {
1306
+ "model_module": "@jupyter-widgets/controls",
1307
+ "model_module_version": "1.5.0",
1308
+ "model_name": "FloatProgressModel",
1309
+ "state": {
1310
+ "_dom_classes": [],
1311
+ "_model_module": "@jupyter-widgets/controls",
1312
+ "_model_module_version": "1.5.0",
1313
+ "_model_name": "FloatProgressModel",
1314
+ "_view_count": null,
1315
+ "_view_module": "@jupyter-widgets/controls",
1316
+ "_view_module_version": "1.5.0",
1317
+ "_view_name": "ProgressView",
1318
+ "bar_style": "success",
1319
+ "description": "",
1320
+ "description_tooltip": null,
1321
+ "layout": "IPY_MODEL_0eb183199e764348bf581e2ffce05ecf",
1322
+ "max": 5.0,
1323
+ "min": 0.0,
1324
+ "orientation": "horizontal",
1325
+ "style": "IPY_MODEL_ae831dbd067b467ab82639cf32c7c94d",
1326
+ "value": 5.0
1327
+ }
1328
+ },
1329
+ "ee353ab052bc474ead46fd0f9ed9e203": {
1330
+ "model_module": "@jupyter-widgets/controls",
1331
+ "model_module_version": "1.5.0",
1332
+ "model_name": "HTMLModel",
1333
+ "state": {
1334
+ "_dom_classes": [],
1335
+ "_model_module": "@jupyter-widgets/controls",
1336
+ "_model_module_version": "1.5.0",
1337
+ "_model_name": "HTMLModel",
1338
+ "_view_count": null,
1339
+ "_view_module": "@jupyter-widgets/controls",
1340
+ "_view_module_version": "1.5.0",
1341
+ "_view_name": "HTMLView",
1342
+ "description": "",
1343
+ "description_tooltip": null,
1344
+ "layout": "IPY_MODEL_ee7ef555c90e4457b64ff561bf94a63c",
1345
+ "placeholder": "​",
1346
+ "style": "IPY_MODEL_bfd66578853d48eeb16725dccf2b9065",
1347
+ "value": " 5/5 [43:19&lt;00:00, 520.03s/it]"
1348
+ }
1349
+ },
1350
+ "ee7ef555c90e4457b64ff561bf94a63c": {
1351
+ "model_module": "@jupyter-widgets/base",
1352
+ "model_module_version": "1.2.0",
1353
+ "model_name": "LayoutModel",
1354
+ "state": {
1355
+ "_model_module": "@jupyter-widgets/base",
1356
+ "_model_module_version": "1.2.0",
1357
+ "_model_name": "LayoutModel",
1358
+ "_view_count": null,
1359
+ "_view_module": "@jupyter-widgets/base",
1360
+ "_view_module_version": "1.2.0",
1361
+ "_view_name": "LayoutView",
1362
+ "align_content": null,
1363
+ "align_items": null,
1364
+ "align_self": null,
1365
+ "border": null,
1366
+ "bottom": null,
1367
+ "display": null,
1368
+ "flex": null,
1369
+ "flex_flow": null,
1370
+ "grid_area": null,
1371
+ "grid_auto_columns": null,
1372
+ "grid_auto_flow": null,
1373
+ "grid_auto_rows": null,
1374
+ "grid_column": null,
1375
+ "grid_gap": null,
1376
+ "grid_row": null,
1377
+ "grid_template_areas": null,
1378
+ "grid_template_columns": null,
1379
+ "grid_template_rows": null,
1380
+ "height": null,
1381
+ "justify_content": null,
1382
+ "justify_items": null,
1383
+ "left": null,
1384
+ "margin": null,
1385
+ "max_height": null,
1386
+ "max_width": null,
1387
+ "min_height": null,
1388
+ "min_width": null,
1389
+ "object_fit": null,
1390
+ "object_position": null,
1391
+ "order": null,
1392
+ "overflow": null,
1393
+ "overflow_x": null,
1394
+ "overflow_y": null,
1395
+ "padding": null,
1396
+ "right": null,
1397
+ "top": null,
1398
+ "visibility": null,
1399
+ "width": null
1400
+ }
1401
+ },
1402
+ "f0a7e9ffa8394babae8a609317de1970": {
1403
+ "model_module": "@jupyter-widgets/controls",
1404
+ "model_module_version": "1.5.0",
1405
+ "model_name": "HTMLModel",
1406
+ "state": {
1407
+ "_dom_classes": [],
1408
+ "_model_module": "@jupyter-widgets/controls",
1409
+ "_model_module_version": "1.5.0",
1410
+ "_model_name": "HTMLModel",
1411
+ "_view_count": null,
1412
+ "_view_module": "@jupyter-widgets/controls",
1413
+ "_view_module_version": "1.5.0",
1414
+ "_view_name": "HTMLView",
1415
+ "description": "",
1416
+ "description_tooltip": null,
1417
+ "layout": "IPY_MODEL_44e788a6798c486db6f6a9d4d910eb9e",
1418
+ "placeholder": "​",
1419
+ "style": "IPY_MODEL_27e595e199224d0d93832f5dd41379e8",
1420
+ "value": "100%"
1421
+ }
1422
+ }
1423
+ },
1424
+ "version_major": 2,
1425
+ "version_minor": 0
1426
+ }
1427
+ }
1428
+ },
1429
+ "nbformat": 4,
1430
+ "nbformat_minor": 5
1431
+ }