Jensen-holm committed on
Commit 895afd0
1 Parent(s): 1720a63

Creating a neural network for the men's data, predicting the Win column


Trained for 10k epochs on my M1 MacBook with Metal GPU acceleration.
Next we need to compare this model's performance against a baseline model
that just blindly picks chalk (the better seed) to win.
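
To make that comparison concrete, a minimal sketch of such a chalk baseline could look like the following. This is illustrative only: the Seed / OppSeed column names and the mns_games_df frame are assumptions, not code from this repo.

    import pandas as pd

    def chalk_baseline(games: pd.DataFrame) -> pd.Series:
        # pick chalk: predict a win whenever the team holds the better
        # (lower-numbered) seed; ties are predicted as losses
        return (games["Seed"] < games["OppSeed"]).astype(int)

    # compare against the same Win column the network trains on, e.g.:
    # baseline_acc = (chalk_baseline(mns_games_df) == mns_games_df["Win"]).mean()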

models/{scoreDist30k.pth → nn10k.pth} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:d11fcf0f9b0ea5b93de0cbfbbaed4447f48ebc03c53be688981c3ceddbc287f7
- size 39998
+ oid sha256:72062bda8133544b2241172907c3614855527a51357bebddaac0e46a7a9ea29f
+ size 18898
src/nn.ipynb CHANGED
@@ -25,7 +25,7 @@
  "name": "stderr",
  "output_type": "stream",
  "text": [
- "/var/folders/v8/0hd98b512cn3ms2rz146k7jw0000gn/T/ipykernel_23752/685274063.py:1: DtypeWarning: Columns (481,482,483) have mixed types. Specify dtype option on import or set low_memory=False.\n",
+ "/var/folders/v8/0hd98b512cn3ms2rz146k7jw0000gn/T/ipykernel_41770/685274063.py:1: DtypeWarning: Columns (481,482,483) have mixed types. Specify dtype option on import or set low_memory=False.\n",
  " detailed_games_df = pd.read_csv(\n"
  ]
  },
@@ -66,6 +66,13 @@
  "wmns_games_df = detailed_games_df[detailed_games_df[\"League\"] == \"W\"]"
  ]
  },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Define Features, Targets, and register data on device"
+ ]
+ },
  {
  "cell_type": "code",
  "execution_count": 4,
@@ -78,6 +85,11 @@
  "    \"FGMDiff mean reg\",\n",
  "    \"FGM3Diff mean reg\",\n",
  "    \"TODiff mean reg\",\n",
+ "\n",
+ "    \"OppScore mean reg\",\n",
+ "    \"OppFGM mean reg\",\n",
+ "    \"OppFGM3 mean reg\",\n",
+ "    \"OppTO mean reg\",\n",
  "]\n",
  "\n",
  "target_cols = [\"Win\"]"
@@ -108,52 +120,70 @@
  },
  {
  "cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
  "# convert data to tensor objects and register to device\n",
- "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ "# DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ "\n",
+ "def get_device() -> str:\n",
+ "    if torch.cuda.is_available():\n",
+ "        return \"cuda\"\n",
+ "    if torch.backends.mps.is_available():\n",
+ "        return \"mps\"\n",
+ "    return \"cpu\"\n",
+ "\n",
+ "DEVICE = get_device()\n",
+ "print(DEVICE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
  "\n",
  "MX_train_T = torch.tensor(\n",
- "    MX_train.values,\n",
- "    dtype=float,\n",
+ "    MX_train.astype(float).values,\n",
+ "    dtype=torch.float32,\n",
  ").to(DEVICE)\n",
  "\n",
  "MX_test_T = torch.tensor(\n",
- "    MX_test.values,\n",
- "    dtype=float,\n",
+ "    MX_test.astype(float).values,\n",
+ "    dtype=torch.float32,\n",
  ").to(DEVICE)\n",
  "\n",
  "My_train_T = torch.tensor(\n",
- "    My_train.values,\n",
- "    dtype=float,\n",
+ "    My_train.astype(float).values,\n",
+ "    dtype=torch.float32,\n",
  ").to(DEVICE)\n",
  "\n",
  "My_test_T = torch.tensor(\n",
- "    My_test.values,\n",
- "    dtype=float,\n",
+ "    My_test.astype(float).values,\n",
+ "    dtype=torch.float32,\n",
  ").to(DEVICE)\n",
  "\n",
- "# same for womens data\n",
- "Wy_test_T = torch.tensor(\n",
- "    Wy_test.values,\n",
- "    dtype=float,\n",
+ "# # same for womens data\n",
+ "WX_train_T = torch.tensor(\n",
+ "    WX_train.values,\n",
+ "    dtype=torch.float32,\n",
  ").to(DEVICE)\n",
  "\n",
- "Wy_test_T = torch.tensor(\n",
- "    Wy_test.values,\n",
- "    dtype=float,\n",
+ "WX_test_T = torch.tensor(\n",
+ "    WX_test.values,\n",
+ "    dtype=torch.float32,\n",
  ").to(DEVICE)\n",
  "\n",
- "Wy_test_T = torch.tensor(\n",
- "    My_test.values,\n",
- "    dtype=float,\n",
+ "Wy_train_T = torch.tensor(\n",
+ "    Wy_train.values,\n",
+ "    dtype=torch.float32,\n",
  ").to(DEVICE)\n",
  "\n",
  "Wy_test_T = torch.tensor(\n",
  "    Wy_test.values,\n",
- "    dtype=float,\n",
+ "    dtype=torch.float32,\n",
  ").to(DEVICE)"
  ]
  },
@@ -168,7 +198,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 9,
+ "execution_count": 8,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -177,13 +207,13 @@
  "class NiglNN(nn.Module):\n",
  "    def __init__(self):\n",
  "        super().__init__()\n",
- "        self.activation_func = nn.ReLU()\n",
+ "        self.activation_func = nn.Sigmoid()\n",
  "        self.layer1 = nn.Linear(num_features, 64) \n",
  "        self.layer2 = nn.Linear(64, 32)\n",
  "        self.layer3 = nn.Linear(32, 16)\n",
  "        self.layer4 = nn.Linear(16, 8)\n",
  "        self.layer5 = nn.Linear(8, 4)\n",
- "        self.layer5 = nn.Linear(4, 1)\n",
+ "        self.layer6 = nn.Linear(4, 1)\n",
  "\n",
  "    def forward(self, x: torch.Tensor):\n",
  "        x = self.layer1(x)\n",
@@ -195,60 +225,93 @@
  "        x = self.layer4(x)\n",
  "        x = self.activation_func(x)\n",
  "        x = self.layer5(x)\n",
+ "        x = self.activation_func(x)\n",
+ "        x = self.layer6(x)\n",
+ "        x = self.activation_func(x)\n",
  "        return x\n"
  ]
  },
  {
  "cell_type": "code",
- "execution_count": 12,
+ "execution_count": 10,
  "metadata": {},
  "outputs": [
  {
- "ename": "RuntimeError",
- "evalue": "mat1 and mat2 must have the same dtype, but got Double and Float",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[12], line 11\u001b[0m\n\u001b[1;32m 5\u001b[0m optimizer \u001b[38;5;241m=\u001b[39m optim\u001b[38;5;241m.\u001b[39mAdam(\n\u001b[1;32m 6\u001b[0m lr\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.001\u001b[39m,\n\u001b[1;32m 7\u001b[0m params\u001b[38;5;241m=\u001b[39mnigl1k\u001b[38;5;241m.\u001b[39mparameters(),\n\u001b[1;32m 8\u001b[0m )\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(epochs):\n\u001b[0;32m---> 11\u001b[0m pred \u001b[38;5;241m=\u001b[39m \u001b[43mnigl1k\u001b[49m\u001b[43m(\u001b[49m\u001b[43mMX_train_T\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 13\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m 15\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss_fn(pred, My_test_T) \n",
- "File \u001b[0;32m/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
- "Cell \u001b[0;32mIn[9], line 15\u001b[0m, in \u001b[0;36mNiglNN.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: torch\u001b[38;5;241m.\u001b[39mTensor):\n\u001b[0;32m---> 15\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlayer1\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mactivation_func(x)\n\u001b[1;32m 17\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayer2(x)\n",
- "File \u001b[0;32m/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
- "File \u001b[0;32m/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/linear.py:114\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
- "\u001b[0;31mRuntimeError\u001b[0m: mat1 and mat2 must have the same dtype, but got Double and Float"
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[1000 / 10000] Binary Cross Entropy: 0.6770758628845215\n",
+ "[2000 / 10000] Binary Cross Entropy: 0.6671037077903748\n",
+ "[3000 / 10000] Binary Cross Entropy: 0.6648934483528137\n",
+ "[4000 / 10000] Binary Cross Entropy: 0.6640341281890869\n",
+ "[5000 / 10000] Binary Cross Entropy: 0.663619875907898\n",
+ "[6000 / 10000] Binary Cross Entropy: 0.6633755564689636\n",
+ "[7000 / 10000] Binary Cross Entropy: 0.6631807088851929\n",
+ "[8000 / 10000] Binary Cross Entropy: 0.663043200969696\n",
+ "[9000 / 10000] Binary Cross Entropy: 0.6629269123077393\n",
+ "[10000 / 10000] Binary Cross Entropy: 0.6629060506820679\n"
  ]
  }
  ],
  "source": [
  "# mens training loop\n",
- "epochs = 1_000\n",
- "nigl1k = NiglNN()\n",
- "loss_fn = nn.MSELoss()\n",
+ "torch.manual_seed(2)\n",
+ "\n",
+ "epochs = 10_000\n",
+ "nigl10k = NiglNN().to(DEVICE)\n",
+ "loss_fn = nn.BCEWithLogitsLoss()\n",
  "optimizer = optim.Adam(\n",
  "    lr=0.001,\n",
- "    params=nigl1k.parameters(),\n",
+ "    params=nigl10k.parameters(),\n",
  ")\n",
  "\n",
- "for i in range(epochs):\n",
- "    pred = nigl1k(MX_train_T)\n",
+ "for epoch in range(1, epochs + 1):\n",
  "    optimizer.zero_grad()\n",
- "\n",
- "    loss = loss_fn(pred, My_test_T) \n",
+ "    pred = nigl10k(MX_train_T)\n",
+ "    loss = loss_fn(pred, My_train_T) \n",
  "    loss.backward()\n",
- "\n",
  "    optimizer.step()\n",
- "    if i % epochs == 0:\n",
- "        print(f\"[{i} / {epochs}] loss = {loss.item()}\")\n"
+ "\n",
+ "    if epoch % 1_000 == 0:\n",
+ "        print(f\"[{epoch} / {epochs}] Binary Cross Entropy: {loss.item()}\")\n"
  ]
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Binary Cross Entropy: 0.6655928492546082\n"
+ ]
+ }
+ ],
+ "source": [
+ "nigl10k.eval()\n",
+ "\n",
+ "with torch.no_grad():\n",
+ "    pred = nigl10k(MX_test_T)\n",
+ "    loss = loss_fn(pred, My_test_T)\n",
+ "    print(f\"Binary Cross Entropy: {loss.item()}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
  "metadata": {},
  "outputs": [],
- "source": []
+ "source": [
+ "# save model\n",
+ "MODEL_DIR = os.path.join(\"..\", \"models\")\n",
+ "\n",
+ "torch.save(\n",
+ "    nigl10k,\n",
+ "    os.path.join(MODEL_DIR, \"nn10k.pth\"),\n",
+ ")"
+ ]
  }
  ],
  "metadata": {