Lederback committed on
Commit
dc92a80
1 Parent(s): c158e96

Upload modelo_treino.ipynb

Files changed (1)
  1. modelo_treino.ipynb +823 -0
modelo_treino.ipynb ADDED
@@ -0,0 +1,823 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tbu_ucRSo5zT"
+ },
+ "source": [
+ "The following is an example of how to use our Sen1Floods11 dataset to train an FCNN. In this example, we train and validate on hand-labeled chips of flood events. However, our dataset includes several other options, detailed in the README. To use a different dataset, as outlined further below, simply replace the train, test, and validation split CSVs and download the corresponding data."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "9TQtMrI_VhKk"
+ },
+ "source": [
+ "Authenticate Google Cloud Platform. Note that to run this code, you must connect your notebook runtime to a GPU."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qCEt8eNtU9Zm"
+ },
+ "outputs": [],
+ "source": [
+ "from google.colab import auth\n",
+ "auth.authenticate_user()\n",
+ "\n",
+ "!curl https://sdk.cloud.google.com | bash\n",
+ "\n",
+ "!gcloud init"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "YkUEnwXQVy4k"
+ },
+ "outputs": [],
+ "source": [
+ "!echo \"deb http://packages.cloud.google.com/apt gcsfuse-bionic main\" > /etc/apt/sources.list.d/gcsfuse.list\n",
+ "!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add -\n",
+ "!apt -qq update\n",
+ "!apt -qq install gcsfuse"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vXGTA6vHVyJX"
+ },
+ "source": [
+ "Install Rasterio"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "NLlVutLzV_pZ"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install rasterio"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hLqL9C2Rg6eB"
+ },
+ "source": [
+ "Define a model checkpoint folder for storing network checkpoints during training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yLlIhE-Hg-Ym"
+ },
+ "outputs": [],
+ "source": [
+ "%cd /home\n",
+ "!sudo mkdir checkpoints"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "mwrDM4AjVnbU"
+ },
+ "source": [
+ "Download the train, test, and validation splits for the hand-labeled flood water data. To use different splits, simply replace these paths with the path to a CSV containing the desired splits."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "RFLsGwdRWuO4"
+ },
+ "outputs": [],
+ "source": [
+ "!gsutil cp gs://sen1floods11/v1.1/splits/flood_handlabeled/flood_train_data.csv .\n",
+ "!gsutil cp gs://sen1floods11/v1.1/splits/flood_handlabeled/flood_test_data.csv .\n",
+ "!gsutil cp gs://sen1floods11/v1.1/splits/flood_handlabeled/flood_valid_data.csv ."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ZCAXpuKVW3eV"
+ },
+ "source": [
+ "Download the raw train, test, and validation data. In this example, we download hand-labeled flood imagery. However, you can simply replace these paths with whichever dataset you would like to use; further documentation of the Sen1Floods11 dataset and its organization is available in the README."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ahAWnrSFW53S"
+ },
+ "outputs": [],
+ "source": [
+ "!sudo mkdir files\n",
+ "!sudo mkdir files/S1\n",
+ "!sudo mkdir files/Labels\n",
+ "\n",
+ "!gsutil -m rsync -r gs://sen1floods11/v1.1/data/flood_events/HandLabeled/S1Hand files/S1\n",
+ "!gsutil -m rsync -r gs://sen1floods11/v1.1/data/flood_events/HandLabeled/LabelHand files/Labels"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_46CazV3XSCD"
+ },
+ "source": [
+ "Define model training hyperparameters"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "fNYQywdWXeLM"
+ },
+ "outputs": [],
+ "source": [
+ "LR = 5e-4\n",
+ "EPOCHS = 100\n",
+ "EPOCHS_PER_UPDATE = 1\n",
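+ "# Note: EPOCHS and EPOCHS_PER_UPDATE are not read by the training driver\n",
+ "# below, which hard-codes 10 training epochs per validation pass; keep\n",
+ "# these in sync if you change that schedule.\n",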
+ "RUNNAME = \"Sen1Floods11\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "W9FJmTnZXjxj"
+ },
+ "source": [
+ "Define functions to process and augment training and testing images"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "mBkfav0Eajqg"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from torchvision import transforms\n",
+ "import torchvision.transforms.functional as F\n",
+ "import random\n",
+ "from PIL import Image\n",
+ "\n",
+ "class InMemoryDataset(torch.utils.data.Dataset):\n",
+ "\n",
+ " def __init__(self, data_list, preprocess_func):\n",
+ " self.data_list = data_list\n",
+ " self.preprocess_func = preprocess_func\n",
+ "\n",
+ " def __getitem__(self, i):\n",
+ " return self.preprocess_func(self.data_list[i])\n",
+ "\n",
+ " def __len__(self):\n",
+ " return len(self.data_list)\n",
+ "\n",
+ "\n",
+ "def processAndAugment(data):\n",
+ " (x,y) = data\n",
+ " im,label = x.copy(), y.copy()\n",
+ "\n",
+ " # convert to PIL for easier transforms\n",
+ " im1 = Image.fromarray(im[0])\n",
+ " im2 = Image.fromarray(im[1])\n",
+ " label = Image.fromarray(label.squeeze())\n",
+ "\n",
+ " # Get params for random transforms\n",
+ " i, j, h, w = transforms.RandomCrop.get_params(im1, (256, 256))\n",
+ "\n",
+ " im1 = F.crop(im1, i, j, h, w)\n",
+ " im2 = F.crop(im2, i, j, h, w)\n",
+ " label = F.crop(label, i, j, h, w)\n",
+ " if random.random() > 0.5:\n",
+ " im1 = F.hflip(im1)\n",
+ " im2 = F.hflip(im2)\n",
+ " label = F.hflip(label)\n",
+ " if random.random() > 0.5:\n",
+ " im1 = F.vflip(im1)\n",
+ " im2 = F.vflip(im2)\n",
+ " label = F.vflip(label)\n",
+ "\n",
+ " norm = transforms.Normalize([0.6851, 0.5235], [0.0820, 0.1102])\n",
+ " im = torch.stack([transforms.ToTensor()(im1).squeeze(), transforms.ToTensor()(im2).squeeze()])\n",
+ " im = norm(im)\n",
+ " label = transforms.ToTensor()(label).squeeze()\n",
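+ " # ToTensor rescales 8-bit masks to [0, 1]; values near 1/255 mean the\n",
+ " # mask was scaled, so restore the original {0, 1, 255} class codes\n",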
+ " if torch.sum(label.gt(.003) * label.lt(.004)):\n",
+ " label *= 255\n",
+ " label = label.round()\n",
+ "\n",
+ " return im, label\n",
+ "\n",
+ "\n",
+ "def processTestIm(data):\n",
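+ " # Resize each tile to 512x512 and split it into four 256x256 quadrants,\n",
+ " # so test-time inputs match the 256x256 crops used in training.\n",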
+ " (x,y) = data\n",
+ " im,label = x.copy(), y.copy()\n",
+ " norm = transforms.Normalize([0.6851, 0.5235], [0.0820, 0.1102])\n",
+ "\n",
+ " # convert to PIL for easier transforms\n",
+ " im_c1 = Image.fromarray(im[0]).resize((512,512))\n",
+ " im_c2 = Image.fromarray(im[1]).resize((512,512))\n",
+ " label = Image.fromarray(label.squeeze()).resize((512,512))\n",
+ "\n",
+ " im_c1s = [F.crop(im_c1, 0, 0, 256, 256), F.crop(im_c1, 0, 256, 256, 256),\n",
+ " F.crop(im_c1, 256, 0, 256, 256), F.crop(im_c1, 256, 256, 256, 256)]\n",
+ " im_c2s = [F.crop(im_c2, 0, 0, 256, 256), F.crop(im_c2, 0, 256, 256, 256),\n",
+ " F.crop(im_c2, 256, 0, 256, 256), F.crop(im_c2, 256, 256, 256, 256)]\n",
+ " labels = [F.crop(label, 0, 0, 256, 256), F.crop(label, 0, 256, 256, 256),\n",
+ " F.crop(label, 256, 0, 256, 256), F.crop(label, 256, 256, 256, 256)]\n",
+ "\n",
+ " ims = [torch.stack((transforms.ToTensor()(x).squeeze(),\n",
+ " transforms.ToTensor()(y).squeeze()))\n",
+ " for (x,y) in zip(im_c1s, im_c2s)]\n",
+ "\n",
+ " ims = [norm(im) for im in ims]\n",
+ " ims = torch.stack(ims)\n",
+ "\n",
+ " labels = [(transforms.ToTensor()(label).squeeze()) for label in labels]\n",
+ " labels = torch.stack(labels)\n",
+ "\n",
+ " if torch.sum(labels.gt(.003) * labels.lt(.004)):\n",
+ " labels *= 255\n",
+ " labels = labels.round()\n",
+ "\n",
+ " return ims, labels"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "uzmZIRuoeAuJ"
+ },
+ "source": [
+ "Load *flood water* train, test, and validation data from splits. In this example, this is the data we will use to train our model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "rQUnYCIBeG21"
+ },
+ "outputs": [],
+ "source": [
+ "from time import time\n",
+ "import csv\n",
+ "import os\n",
+ "import numpy as np\n",
+ "import rasterio\n",
+ "\n",
+ "def getArrFlood(fname):\n",
+ " return rasterio.open(fname).read()\n",
+ "\n",
+ "def download_flood_water_data_from_list(l):\n",
+ " i = 0\n",
+ " tot_nan = 0\n",
+ " tot_good = 0\n",
+ " flood_data = []\n",
+ " for (im_fname, mask_fname) in l:\n",
+ " if not os.path.exists(os.path.join(\"files/\", im_fname)):\n",
+ " continue\n",
+ " arr_x = np.nan_to_num(getArrFlood(os.path.join(\"files/\", im_fname)))\n",
+ " arr_y = getArrFlood(os.path.join(\"files/\", mask_fname))\n",
+ " arr_y[arr_y == -1] = 255\n",
+ "\n",
+ " arr_x = np.clip(arr_x, -50, 1)\n",
+ " arr_x = (arr_x + 50) / 51\n",
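+ " # -1 marks missing label data; remapping it to 255 lets CrossEntropyLoss\n",
+ " # (ignore_index=255, defined below) skip those pixels. The clip and\n",
+ " # rescale map Sentinel-1 backscatter (dB) from [-50, 1] to [0, 1].\n",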
+ "\n",
+ " if i % 100 == 0:\n",
+ " print(im_fname, mask_fname)\n",
+ " i += 1\n",
+ " flood_data.append((arr_x,arr_y))\n",
+ "\n",
+ " return flood_data\n",
+ "\n",
+ "def load_flood_train_data(input_root, label_root):\n",
+ " fname = \"flood_train_data.csv\"\n",
+ " training_files = []\n",
+ " with open(fname) as f:\n",
+ " for line in csv.reader(f):\n",
+ " training_files.append(tuple((input_root+line[0], label_root+line[1])))\n",
+ "\n",
+ " return download_flood_water_data_from_list(training_files)\n",
+ "\n",
+ "def load_flood_valid_data(input_root, label_root):\n",
+ " fname = \"flood_valid_data.csv\"\n",
+ " validation_files = []\n",
+ " with open(fname) as f:\n",
+ " for line in csv.reader(f):\n",
+ " validation_files.append(tuple((input_root+line[0], label_root+line[1])))\n",
+ "\n",
+ " return download_flood_water_data_from_list(validation_files)\n",
+ "\n",
+ "def load_flood_test_data(input_root, label_root):\n",
+ " fname = \"flood_test_data.csv\"\n",
+ " testing_files = []\n",
+ " with open(fname) as f:\n",
+ " for line in csv.reader(f):\n",
+ " testing_files.append(tuple((input_root+line[0], label_root+line[1])))\n",
+ "\n",
+ " return download_flood_water_data_from_list(testing_files)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cFp9jrHYfOUh"
+ },
+ "source": [
+ "Load training data and validation data. Note that here, we have chosen to train and validate our model on flood data. However, you can simply replace the load function call with one of the options defined above to load a different dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ZcqPlsjBffXx"
+ },
+ "outputs": [],
+ "source": [
+ "train_data = load_flood_train_data('S1/', 'Labels/')\n",
+ "train_dataset = InMemoryDataset(train_data, processAndAugment)\n",
+ "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, sampler=None,\n",
+ " batch_sampler=None, num_workers=0, collate_fn=None,\n",
+ " pin_memory=True, drop_last=False, timeout=0,\n",
+ " worker_init_fn=None)\n",
+ "train_iter = iter(train_loader)\n",
+ "\n",
+ "valid_data = load_flood_valid_data('S1/', 'Labels/')\n",
+ "valid_dataset = InMemoryDataset(valid_data, processTestIm)\n",
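+ "# processTestIm returns four 256x256 crops per tile, so the collate_fn\n",
+ "# below concatenates them into a single batch dimension.\n",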
+ "valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=4, shuffle=True, sampler=None,\n",
+ " batch_sampler=None, num_workers=0, collate_fn=lambda x: (torch.cat([a[0] for a in x], 0), torch.cat([a[1] for a in x], 0)),\n",
+ " pin_memory=True, drop_last=False, timeout=0,\n",
+ " worker_init_fn=None)\n",
+ "valid_iter = iter(valid_loader)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "i3aAhUi2fp7M"
+ },
+ "source": [
+ "Define the network. For our purposes, we use an FCN with a ResNet-50 backbone. However, if you wish to test a different model architecture, optimizer, or loss function, you can simply replace those here."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "5cp4uXI1f9dr"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torchvision.models as models\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "net = models.segmentation.fcn_resnet50(pretrained=False, num_classes=2, pretrained_backbone=False)\n",
+ "net.backbone.conv1 = nn.Conv2d(2, 64, kernel_size=7, stride=2, padding=3, bias=False)\n",
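+ "# The stem above is swapped for a 2-channel conv so the network accepts\n",
+ "# the two Sentinel-1 polarization bands (VV, VH) instead of RGB.\n",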
+ "\n",
+ "criterion = nn.CrossEntropyLoss(weight=torch.tensor([1,8]).float().cuda(), ignore_index=255)\n",
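+ "# Weighting the water class 8x counters the land/water class imbalance,\n",
+ "# and ignore_index=255 skips the pixels masked out during loading.\n",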
+ "optimizer = torch.optim.AdamW(net.parameters(),lr=LR)\n",
+ "scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, len(train_loader) * 10, T_mult=2, eta_min=0, last_epoch=-1)\n",
+ "\n",
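+ "# BatchNorm statistics are unreliable at small batch sizes, so replace\n",
+ "# every BatchNorm2d with GroupNorm, which is independent of batch size.\n",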
+ "def convertBNtoGN(module, num_groups=16):\n",
+ " if isinstance(module, torch.nn.modules.batchnorm.BatchNorm2d):\n",
+ " mod = nn.GroupNorm(num_groups, module.num_features,\n",
+ " eps=module.eps, affine=module.affine)\n",
+ " if module.affine:\n",
+ " mod.weight.data = module.weight.data.clone().detach()\n",
+ " mod.bias.data = module.bias.data.clone().detach()\n",
+ " return mod\n",
+ "\n",
+ " for name, child in module.named_children():\n",
+ " module.add_module(name, convertBNtoGN(child, num_groups=num_groups))\n",
+ "\n",
+ " return module\n",
+ "\n",
+ "net = convertBNtoGN(net)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "g_Sy3ALGgQjf"
+ },
+ "source": [
+ "Define assessment metrics. For our purposes, we use overall accuracy and intersection over union (IoU) of the water class. However, we also include functions for calculating true positives, false positives, true negatives, and false negatives."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "bwxC-fVBgUIb"
+ },
+ "outputs": [],
+ "source": [
+ "def computeIOU(output, target):\n",
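+ " # With binary {0, 1} labels, output * target counts true positives,\n",
+ " # so this is the IoU of the water class (class 1).\n",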
+ " output = torch.argmax(output, dim=1).flatten()\n",
+ " target = target.flatten()\n",
+ "\n",
+ " no_ignore = target.ne(255).cuda()\n",
+ " output = output.masked_select(no_ignore)\n",
+ " target = target.masked_select(no_ignore)\n",
+ " intersection = torch.sum(output * target)\n",
+ " union = torch.sum(target) + torch.sum(output) - intersection\n",
+ " iou = (intersection + .0000001) / (union + .0000001)\n",
+ "\n",
+ " if iou != iou:\n",
+ " print(\"failed, replacing with 0\")\n",
+ " iou = torch.tensor(0).float()\n",
+ "\n",
+ " return iou\n",
+ "\n",
+ "def computeAccuracy(output, target):\n",
+ " output = torch.argmax(output, dim=1).flatten()\n",
+ " target = target.flatten()\n",
+ "\n",
+ " no_ignore = target.ne(255).cuda()\n",
+ " output = output.masked_select(no_ignore)\n",
+ " target = target.masked_select(no_ignore)\n",
+ " correct = torch.sum(output.eq(target))\n",
+ "\n",
+ " return correct.float() / len(target)\n",
+ "\n",
+ "def truePositives(output, target):\n",
+ " output = torch.argmax(output, dim=1).flatten()\n",
+ " target = target.flatten()\n",
+ " no_ignore = target.ne(255).cuda()\n",
+ " output = output.masked_select(no_ignore)\n",
+ " target = target.masked_select(no_ignore)\n",
+ " correct = torch.sum(output * target)\n",
+ "\n",
+ " return correct\n",
+ "\n",
+ "def trueNegatives(output, target):\n",
+ " output = torch.argmax(output, dim=1).flatten()\n",
+ " target = target.flatten()\n",
+ " no_ignore = target.ne(255).cuda()\n",
+ " output = output.masked_select(no_ignore)\n",
+ " target = target.masked_select(no_ignore)\n",
+ " output = (output == 0)\n",
+ " target = (target == 0)\n",
+ " correct = torch.sum(output * target)\n",
+ "\n",
+ " return correct\n",
+ "\n",
+ "def falsePositives(output, target):\n",
+ " output = torch.argmax(output, dim=1).flatten()\n",
+ " target = target.flatten()\n",
+ " no_ignore = target.ne(255).cuda()\n",
+ " output = output.masked_select(no_ignore)\n",
+ " target = target.masked_select(no_ignore)\n",
+ " output = (output == 1)\n",
+ " target = (target == 0)\n",
+ " correct = torch.sum(output * target)\n",
+ "\n",
+ " return correct\n",
+ "\n",
+ "def falseNegatives(output, target):\n",
+ " output = torch.argmax(output, dim=1).flatten()\n",
+ " target = target.flatten()\n",
+ " no_ignore = target.ne(255).cuda()\n",
+ " output = output.masked_select(no_ignore)\n",
+ " target = target.masked_select(no_ignore)\n",
+ " output = (output == 0)\n",
+ " target = (target == 1)\n",
+ " correct = torch.sum(output * target)\n",
+ "\n",
+ " return correct"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "lun5tGoYgjWX"
+ },
+ "source": [
+ "Define training loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DubsYZ8GgkxD"
+ },
+ "outputs": [],
+ "source": [
+ "training_losses = []\n",
+ "training_accuracies = []\n",
+ "training_ious = []\n",
+ "\n",
+ "def train_loop(inputs, labels, net, optimizer, scheduler):\n",
+ " global running_loss\n",
+ " global running_iou\n",
+ " global running_count\n",
+ " global running_accuracy\n",
+ "\n",
+ " # zero the parameter gradients\n",
+ " optimizer.zero_grad()\n",
+ " net = net.cuda()\n",
+ "\n",
+ " # forward + backward + optimize\n",
+ " outputs = net(inputs.cuda())\n",
+ " loss = criterion(outputs[\"out\"], labels.long().cuda())\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ " scheduler.step()\n",
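+ " # the scheduler is stepped per batch: the first cosine restart period\n",
+ " # was set to len(train_loader) * 10 steps, i.e. 10 epochs\n",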
+ "\n",
+ " # .item() detaches the metrics; accumulating live tensors would keep\n",
+ " # every batch's graph (and GPU memory) alive\n",
+ " running_loss += loss.item()\n",
+ " running_iou += computeIOU(outputs[\"out\"], labels.cuda()).item()\n",
+ " running_accuracy += computeAccuracy(outputs[\"out\"], labels.cuda()).item()\n",
+ " running_count += 1"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "iM3Jz__hgshh"
+ },
+ "source": [
+ "Define validation loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "_GmVaoRvguic"
+ },
+ "outputs": [],
+ "source": [
+ "valid_losses = []\n",
+ "valid_accuracies = []\n",
+ "valid_ious = []\n",
+ "\n",
+ "def validation_loop(validation_data_loader, net):\n",
+ " global running_loss\n",
+ " global running_iou\n",
+ " global running_count\n",
+ " global running_accuracy\n",
+ " global max_valid_iou\n",
+ "\n",
+ " global training_losses\n",
+ " global training_accuracies\n",
+ " global training_ious\n",
+ " global valid_losses\n",
+ " global valid_accuracies\n",
+ " global valid_ious\n",
+ "\n",
+ " net = net.eval()\n",
+ " net = net.cuda()\n",
+ " count = 0\n",
+ " iou = 0\n",
+ " loss = 0\n",
+ " accuracy = 0\n",
+ " with torch.no_grad():\n",
+ " for (images, labels) in validation_data_loader:\n",
+ " net = net.cuda()\n",
+ " outputs = net(images.cuda())\n",
+ " valid_loss = criterion(outputs[\"out\"], labels.long().cuda())\n",
+ " valid_iou = computeIOU(outputs[\"out\"], labels.cuda())\n",
+ " valid_accuracy = computeAccuracy(outputs[\"out\"], labels.cuda())\n",
+ " # .item() keeps the running sums (and the lists below) as plain floats\n",
+ " iou += valid_iou.item()\n",
+ " loss += valid_loss.item()\n",
+ " accuracy += valid_accuracy.item()\n",
+ " count += 1\n",
+ "\n",
+ " iou = iou / count\n",
+ " accuracy = accuracy / count\n",
+ "\n",
+ " if iou > max_valid_iou:\n",
+ " max_valid_iou = iou\n",
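+ " # `i` is the iteration counter from the driver loop at the bottom of\n",
+ " # the notebook (a global), used here to tag the checkpoint filename\n",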
+ " save_path = os.path.join(\"checkpoints\", \"{}_{}_{}.cp\".format(RUNNAME, i, iou))\n",
+ " torch.save(net.state_dict(), save_path)\n",
+ " print(\"model saved at\", save_path)\n",
+ "\n",
+ " loss = loss / count\n",
+ " print(\"Training Loss:\", running_loss / running_count)\n",
+ " print(\"Training IOU:\", running_iou / running_count)\n",
+ " print(\"Training Accuracy:\", running_accuracy / running_count)\n",
+ " print(\"Validation Loss:\", loss)\n",
+ " print(\"Validation IOU:\", iou)\n",
+ " print(\"Validation Accuracy:\", accuracy)\n",
+ "\n",
+ "\n",
+ " training_losses.append(running_loss / running_count)\n",
+ " training_accuracies.append(running_accuracy / running_count)\n",
+ " training_ious.append(running_iou / running_count)\n",
+ " valid_losses.append(loss)\n",
+ " valid_accuracies.append(accuracy)\n",
+ " valid_ious.append(iou)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "DBMattYshiUj"
+ },
+ "source": [
+ "Define testing loop (here, you can replace assessment metrics)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "mI_mhL_ehjot"
+ },
+ "outputs": [],
+ "source": [
+ "def test_loop(test_data_loader, net):\n",
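+ " # expects a loader built like valid_loader above, e.g. from\n",
+ " # load_flood_test_data('S1/', 'Labels/') with processTestIm and the\n",
+ " # same concatenating collate_fn\n",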
+ " net = net.eval()\n",
+ " net = net.cuda()\n",
+ " count = 0\n",
+ " iou = 0\n",
+ " loss = 0\n",
+ " accuracy = 0\n",
+ " with torch.no_grad():\n",
+ " for (images, labels) in tqdm(test_data_loader):\n",
+ " net = net.cuda()\n",
+ " outputs = net(images.cuda())\n",
+ " valid_loss = criterion(outputs[\"out\"], labels.long().cuda())\n",
+ " valid_iou = computeIOU(outputs[\"out\"], labels.cuda())\n",
+ " iou += valid_iou\n",
+ " accuracy += computeAccuracy(outputs[\"out\"], labels.cuda())\n",
+ " count += 1\n",
+ "\n",
+ " iou = iou / count\n",
+ " print(\"Test IOU:\", iou)\n",
+ " print(\"Test Accuracy:\", accuracy / count)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cy9Fii06h17Q"
+ },
+ "source": [
+ "Define training and validation scheme"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "NZuKVC6wh4Go"
+ },
+ "outputs": [],
+ "source": [
+ "from tqdm.notebook import tqdm\n",
+ "from IPython.display import clear_output\n",
+ "\n",
+ "running_loss = 0\n",
+ "running_iou = 0\n",
+ "running_count = 0\n",
+ "running_accuracy = 0\n",
+ "\n",
+ "training_losses = []\n",
+ "training_accuracies = []\n",
+ "training_ious = []\n",
+ "valid_losses = []\n",
+ "valid_accuracies = []\n",
+ "valid_ious = []\n",
+ "\n",
+ "\n",
+ "def train_epoch(net, optimizer, scheduler, train_iter):\n",
+ " for (inputs, labels) in tqdm(train_iter):\n",
+ " train_loop(inputs.cuda(), labels.cuda(), net.cuda(), optimizer, scheduler)\n",
+ "\n",
+ "\n",
+ "def train_validation_loop(net, optimizer, scheduler, train_loader,\n",
+ " valid_loader, num_epochs, cur_epoch):\n",
+ " global running_loss\n",
+ " global running_iou\n",
+ " global running_count\n",
+ " global running_accuracy\n",
+ " net = net.train()\n",
+ " running_loss = 0\n",
+ " running_iou = 0\n",
+ " running_count = 0\n",
+ " running_accuracy = 0\n",
+ "\n",
+ " for i in tqdm(range(num_epochs)):\n",
+ " train_iter = iter(train_loader)\n",
+ " train_epoch(net, optimizer, scheduler, train_iter)\n",
+ " clear_output()\n",
+ "\n",
+ " print(\"Current Epoch:\", cur_epoch)\n",
+ " validation_loop(iter(valid_loader), net)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "k3I88aY5iAWD"
+ },
+ "source": [
+ "Train model and assess metrics over epochs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "8MRpxUGWiDTu"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from IPython.display import display\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "max_valid_iou = 0\n",
+ "start = 0\n",
+ "\n",
+ "epochs = []\n",
+ "training_losses = []\n",
+ "training_accuracies = []\n",
+ "training_ious = []\n",
+ "valid_losses = []\n",
+ "valid_accuracies = []\n",
+ "valid_ious = []\n",
+ "\n",
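+ "# Each iteration below runs 10 training epochs followed by one validation\n",
+ "# pass (with checkpointing on improved IoU); the bound of 1000 makes this\n",
+ "# effectively open-ended, so interrupt the cell once validation IoU plateaus.\n",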
+ "for i in range(start, 1000):\n",
+ " train_validation_loop(net, optimizer, scheduler, train_loader, valid_loader, 10, i)\n",
+ " epochs.append(i)\n",
+ " x = epochs\n",
+ " plt.plot(x, training_losses, label='training losses')\n",
+ " plt.plot(x, training_accuracies, 'tab:orange', label='training accuracy')\n",
+ " plt.plot(x, training_ious, 'tab:purple', label='training iou')\n",
+ " plt.plot(x, valid_losses, label='valid losses')\n",
+ " plt.plot(x, valid_accuracies, 'tab:red',label='valid accuracy')\n",
+ " plt.plot(x, valid_ious, 'tab:green',label='valid iou')\n",
+ " plt.legend(loc=\"upper left\")\n",
+ "\n",
+ " plt.show()\n",
+ "\n",
+ " print(\"max valid iou:\", max_valid_iou)"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "name": "Train.ipynb",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.5"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": false
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+ }