chainyo committed on
Commit
0719c14
1 Parent(s): 3110ea7

test notebook

Browse files
Files changed (1) hide show
  1. finetuning.ipynb +974 -0
finetuning.ipynb ADDED
@@ -0,0 +1,974 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import json\n",
10
+ "import pytorch_lightning as pl\n",
11
+ "import torch\n",
12
+ "import torchmetrics\n",
13
+ "\n",
14
+ "from datasets import load_dataset, load_metric\n",
15
+ "\n",
16
+ "from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation\n",
17
+ "\n",
18
+ "from torch import nn\n",
19
+ "from torch.utils.data import DataLoader, Dataset, random_split\n",
20
+ "\n",
21
+ "from tqdm.notebook import tqdm"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 4,
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "class SemanticSegmentationDataset(Dataset):\n",
31
+ " \"\"\"Image segmentation datasets.\"\"\"\n",
32
+ "\n",
33
+ " def __init__(\n",
34
+ " self, \n",
35
+ " dataset: torch.utils.data.dataset.Subset, \n",
36
+ " feature_extractor = SegformerFeatureExtractor(reduce_labels=True),\n",
37
+ " ):\n",
38
+ " \"\"\"\n",
39
+ " Initialize the dataset with the given feature extractor and split.\n",
40
+ "\n",
41
+ " Parameters\n",
42
+ " ----------\n",
43
+ " dataset : Dataset\n",
44
+ " The dataset to use.\n",
45
+ " feature_extractor : FeatureExtractor, optional\n",
46
+ " The feature extractor to use. The default is SegformerFeatureExtractor.\n",
47
+ " \"\"\"\n",
48
+ " self.dataset = dataset\n",
49
+ " self.feature_extractor = feature_extractor\n",
50
+ " self.length = len(self.dataset)\n",
51
+ " print(f\"Loaded {self.length} samples.\")\n",
52
+ "\n",
53
+ "\n",
54
+ " def __len__(self):\n",
55
+ " \"\"\"Return the number of samples in the dataset.\"\"\"\n",
56
+ " return self.length\n",
57
+ "\n",
58
+ "\n",
59
+ " def __getitem__(self, index: int):\n",
60
+ " \"\"\"Get the sample at the given index.\"\"\"\n",
61
+ " image = self.dataset[index][\"pixel_values\"]\n",
62
+ " label = self.dataset[index][\"label\"]\n",
63
+ "\n",
64
+ " encoded_inputs = self.feature_extractor(image, label, return_tensors=\"pt\")\n",
65
+ "\n",
66
+ " for k, v in encoded_inputs.items():\n",
67
+ " encoded_inputs[k].squeeze_() # remove batch dimension\n",
68
+ "\n",
69
+ " return encoded_inputs"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": 3,
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "BATCH_SIZE = 32\n",
79
+ "HUB_DIR = \"segments/sidewalk-semantic\"\n",
80
+ "EPOCHS = 200"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 5,
86
+ "metadata": {},
87
+ "outputs": [
88
+ {
89
+ "name": "stderr",
90
+ "output_type": "stream",
91
+ "text": [
92
+ "Using custom data configuration segments--sidewalk-semantic-2-f89d0845be9cadc9\n",
93
+ "Reusing dataset parquet (/home/chainyo/.cache/huggingface/datasets/segments___parquet/segments--sidewalk-semantic-2-f89d0845be9cadc9/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)\n"
94
+ ]
95
+ },
96
+ {
97
+ "name": "stdout",
98
+ "output_type": "stream",
99
+ "text": [
100
+ "Loaded 800 samples.\n",
101
+ "Loaded 200 samples.\n"
102
+ ]
103
+ }
104
+ ],
105
+ "source": [
106
+ "dataset = load_dataset(HUB_DIR, split=\"train\")\n",
107
+ "\n",
108
+ "train_dataset, val_dataset = random_split(dataset, [int(0.8 * len(dataset)), len(dataset) - int(0.8 * len(dataset))])\n",
109
+ "train_dataset = SemanticSegmentationDataset(train_dataset)\n",
110
+ "val_dataset = SemanticSegmentationDataset(val_dataset)"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": 5,
116
+ "metadata": {},
117
+ "outputs": [],
118
+ "source": [
119
+ "train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
120
+ "val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": 6,
126
+ "metadata": {},
127
+ "outputs": [
128
+ {
129
+ "name": "stdout",
130
+ "output_type": "stream",
131
+ "text": [
132
+ "pixel_values torch.Size([32, 3, 512, 512])\n",
133
+ "labels torch.Size([32, 512, 512])\n"
134
+ ]
135
+ }
136
+ ],
137
+ "source": [
138
+ "batch = next(iter(train_dataloader))\n",
139
+ "\n",
140
+ "for k, v in batch.items():\n",
141
+ " print(k, v.shape)"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": 7,
147
+ "metadata": {},
148
+ "outputs": [],
149
+ "source": [
150
+ "# class SidewalkSegmentationModel(pl.LightningModule):\n",
151
+ "# def __init__(self, num_classes: int, learning_rate: float = 6e-5):\n",
152
+ "# super().__init__()\n",
153
+ "# self.model = SegformerForSemanticSegmentation.from_pretrained(\n",
154
+ "# \"nvidia/mit-b0\", num_labels=num_classes, id2label=id2label, label2id=label2id,\n",
155
+ "# )\n",
156
+ "# self.learning_rate = learning_rate\n",
157
+ "# self.metric = load_metric(\"mean_iou\")\n",
158
+ "\n",
159
+ " \n",
160
+ "# def forward(self, *args, **kwargs):\n",
161
+ "# return self.model(*args, **kwargs)\n",
162
+ "\n",
163
+ " \n",
164
+ "# def training_step(self, batch, batch_idx):\n",
165
+ "# pixel_values = batch[\"pixel_values\"]\n",
166
+ "# labels = batch[\"labels\"]\n",
167
+ "\n",
168
+ "# outputs = self(pixel_values=pixel_values, labels=labels)\n",
169
+ "# loss, logits = outputs.loss, outputs.logits\n",
170
+ "\n",
171
+ " \n",
172
+ "# def configure_optimizers(self) -> torch.optim.AdamW:\n",
173
+ "# \"\"\"\n",
174
+ "# Configure the optimizer.\n",
175
+ "# Returns\n",
176
+ "# -------\n",
177
+ "# torch.optim.AdamW\n",
178
+ "# Optimizer for the model\n",
179
+ "# \"\"\"\n",
180
+ "# return torch.optim.AdamW(model.parameters(), lr=self.learning_rate)"
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": 11,
186
+ "metadata": {},
187
+ "outputs": [
188
+ {
189
+ "name": "stdout",
190
+ "output_type": "stream",
191
+ "text": [
192
+ "{0: 'unlabeled', 1: 'flat-road', 2: 'flat-sidewalk', 3: 'flat-crosswalk', 4: 'flat-cyclinglane', 5: 'flat-parkingdriveway', 6: 'flat-railtrack', 7: 'flat-curb', 8: 'human-person', 9: 'human-rider', 10: 'vehicle-car', 11: 'vehicle-truck', 12: 'vehicle-bus', 13: 'vehicle-tramtrain', 14: 'vehicle-motorcycle', 15: 'vehicle-bicycle', 16: 'vehicle-caravan', 17: 'vehicle-cartrailer', 18: 'construction-building', 19: 'construction-door', 20: 'construction-wall', 21: 'construction-fenceguardrail', 22: 'construction-bridge', 23: 'construction-tunnel', 24: 'construction-stairs', 25: 'object-pole', 26: 'object-trafficsign', 27: 'object-trafficlight', 28: 'nature-vegetation', 29: 'nature-terrain', 30: 'sky', 31: 'void-ground', 32: 'void-dynamic', 33: 'void-static', 34: 'void-unclear'}\n"
193
+ ]
194
+ },
195
+ {
196
+ "name": "stderr",
197
+ "output_type": "stream",
198
+ "text": [
199
+ "Some weights of the model checkpoint at nvidia/mit-b0 were not used when initializing SegformerForSemanticSegmentation: ['classifier.weight', 'classifier.bias']\n",
200
+ "- This IS expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
201
+ "- This IS NOT expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
202
+ "Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.classifier.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.linear_c.1.proj.weight', 'decode_head.classifier.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.batch_norm.bias', 'decode_head.linear_c.2.proj.bias']\n",
203
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
204
+ ]
205
+ }
206
+ ],
207
+ "source": [
208
+ "id2label_file = json.load(open(\"id2label.json\", \"r\"))\n",
209
+ "id2label = {int(k): v for k, v in id2label_file.items()}\n",
210
+ "print(id2label)\n",
211
+ "label2id = {v: k for k, v in id2label_file.items()}\n",
212
+ "num_labels = len(id2label)\n",
213
+ "\n",
214
+ "model = SegformerForSemanticSegmentation.from_pretrained(\n",
215
+ " \"nvidia/mit-b0\", num_labels=num_labels, id2label=id2label, label2id=label2id,\n",
216
+ ")"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "code",
221
+ "execution_count": 9,
222
+ "metadata": {},
223
+ "outputs": [],
224
+ "source": [
225
+ "metric = load_metric(\"mean_iou\")"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": 10,
231
+ "metadata": {},
232
+ "outputs": [
233
+ {
234
+ "data": {
235
+ "application/vnd.jupyter.widget-view+json": {
236
+ "model_id": "091f8e1587f64625a4bbbf04f13a840e",
237
+ "version_major": 2,
238
+ "version_minor": 0
239
+ },
240
+ "text/plain": [
241
+ " 0%| | 0/25 [00:00<?, ?it/s]"
242
+ ]
243
+ },
244
+ "metadata": {},
245
+ "output_type": "display_data"
246
+ },
247
+ {
248
+ "name": "stderr",
249
+ "output_type": "stream",
250
+ "text": [
251
+ "/home/chainyo/.cache/huggingface/modules/datasets_modules/metrics/mean_iou/d4add40cf977cdd73590b5873fa830f3f13adb678f6777a29fb07b7c81d14342/mean_iou.py:259: RuntimeWarning: invalid value encountered in true_divide\n",
252
+ " acc = total_area_intersect / total_area_label\n"
253
+ ]
254
+ },
255
+ {
256
+ "name": "stdout",
257
+ "output_type": "stream",
258
+ "text": [
259
+ "Epoch 0/200 Batch 0/25 Loss 3.5735 Metrics {'mean_iou': 0.005477220659286406, 'mean_accuracy': 0.03572801337234697, 'overall_accuracy': 0.0286087999162653, 'per_category_iou': array([2.25093889e-02, 1.39043249e-03, 7.61652063e-03, 1.08658706e-02,\n",
260
+ " 1.00475237e-02, 0.00000000e+00, 2.03193907e-04, 1.38439962e-03,\n",
261
+ " 3.06388194e-05, 2.21495741e-03, 2.81951515e-05, 0.00000000e+00,\n",
262
+ " 0.00000000e+00, 1.13529929e-04, 0.00000000e+00, 0.00000000e+00,\n",
263
+ " 0.00000000e+00, 7.05787860e-03, 8.48350742e-03, 2.38476991e-03,\n",
264
+ " 1.76276224e-03, 2.17775404e-03, 0.00000000e+00, 1.04358386e-03,\n",
265
+ " 4.58714414e-04, 6.41413308e-05, 1.11431787e-04, 7.99247870e-02,\n",
266
+ " 8.28409255e-03, 2.07198485e-03, 6.12199013e-04, 6.53992687e-03,\n",
267
+ " 1.42611461e-02, 5.93919660e-05, 0.00000000e+00]), 'per_category_accuracy': array([2.49789663e-02, 1.40417891e-03, 5.70403118e-02, 1.19771803e-02,\n",
268
+ " 1.35858556e-02, nan, 2.43287079e-04, 1.42133234e-02,\n",
269
+ " 9.06344411e-03, 2.86141687e-03, 1.51057402e-03, nan,\n",
270
+ " nan, 1.54786781e-04, 0.00000000e+00, 0.00000000e+00,\n",
271
+ " nan, 7.59912501e-03, 1.08030661e-01, 7.46983779e-03,\n",
272
+ " 3.63612647e-03, 2.03376823e-01, nan, 1.34638923e-02,\n",
273
+ " 4.23971037e-03, 2.44969379e-03, 2.61780105e-02, 1.44916888e-01,\n",
274
+ " 1.09119869e-02, 2.12033947e-03, 5.77414410e-02, 8.69919527e-02,\n",
275
+ " 1.99881736e-01, 2.00708383e-02, nan])}\n"
276
+ ]
277
+ },
278
+ {
279
+ "data": {
280
+ "application/vnd.jupyter.widget-view+json": {
281
+ "model_id": "8b14f2d3fe2044bf8940697dfce1f2d9",
282
+ "version_major": 2,
283
+ "version_minor": 0
284
+ },
285
+ "text/plain": [
286
+ " 0%| | 0/25 [00:00<?, ?it/s]"
287
+ ]
288
+ },
289
+ "metadata": {},
290
+ "output_type": "display_data"
291
+ },
292
+ {
293
+ "name": "stdout",
294
+ "output_type": "stream",
295
+ "text": [
296
+ "Epoch 1/200 Batch 0/25 Loss 2.3349 Metrics {'mean_iou': 0.08249196196840773, 'mean_accuracy': 0.12242777101928329, 'overall_accuracy': 0.5340892205419953, 'per_category_iou': array([3.12479410e-01, 6.01364321e-01, 1.11562075e-02, 2.09116315e-02,\n",
297
+ " 1.95069293e-02, 0.00000000e+00, 1.09903909e-04, 1.13759900e-03,\n",
298
+ " 1.31593453e-03, 4.16974851e-01, 9.43479560e-04, 0.00000000e+00,\n",
299
+ " 0.00000000e+00, 2.71726824e-05, 7.02479634e-05, 1.71339744e-06,\n",
300
+ " 1.01240638e-04, 3.64420274e-01, 5.22553547e-03, 1.01179313e-03,\n",
301
+ " 5.40034112e-03, 6.14980455e-04, 0.00000000e+00, 1.37384839e-03,\n",
302
+ " 1.58339239e-03, 2.28333709e-05, 4.11671538e-05, 5.37431493e-01,\n",
303
+ " 6.31855647e-02, 4.88071958e-01, 5.19480641e-03, 4.67615947e-03,\n",
304
+ " 2.28171525e-02, 4.67269054e-05, 0.00000000e+00]), 'per_category_accuracy': array([4.06453404e-01, 7.78244778e-01, 2.33772733e-02, 2.33628638e-02,\n",
305
+ " 2.53283555e-02, nan, 1.11846810e-04, 3.42729695e-03,\n",
306
+ " 9.61392116e-03, 6.59758916e-01, 1.07169352e-02, 0.00000000e+00,\n",
307
+ " 0.00000000e+00, 1.71831147e-04, 1.20609014e-04, 1.32753642e-04,\n",
308
+ " 7.55138223e-03, 5.08304753e-01, 2.45514876e-02, 1.15429656e-03,\n",
309
+ " 6.00626887e-03, 1.89040602e-02, 0.00000000e+00, 3.03657284e-03,\n",
310
+ " 3.07788162e-03, 1.93985207e-04, 1.35743374e-03, 8.37114885e-01,\n",
311
+ " 7.86780250e-02, 5.18267366e-01, 1.14241257e-02, 1.51054789e-02,\n",
312
+ " 6.31875993e-02, 1.38005816e-03, nan])}\n"
313
+ ]
314
+ },
315
+ {
316
+ "data": {
317
+ "application/vnd.jupyter.widget-view+json": {
318
+ "model_id": "1bcd35da7d4a43a0958776974bfa7918",
319
+ "version_major": 2,
320
+ "version_minor": 0
321
+ },
322
+ "text/plain": [
323
+ " 0%| | 0/25 [00:00<?, ?it/s]"
324
+ ]
325
+ },
326
+ "metadata": {},
327
+ "output_type": "display_data"
328
+ },
329
+ {
330
+ "name": "stderr",
331
+ "output_type": "stream",
332
+ "text": [
333
+ "/home/chainyo/.cache/huggingface/modules/datasets_modules/metrics/mean_iou/d4add40cf977cdd73590b5873fa830f3f13adb678f6777a29fb07b7c81d14342/mean_iou.py:258: RuntimeWarning: invalid value encountered in true_divide\n",
334
+ " iou = total_area_intersect / total_area_union\n"
335
+ ]
336
+ },
337
+ {
338
+ "name": "stdout",
339
+ "output_type": "stream",
340
+ "text": [
341
+ "Epoch 2/200 Batch 0/25 Loss 1.9441 Metrics {'mean_iou': 0.1115145961827685, 'mean_accuracy': 0.1571059701593332, 'overall_accuracy': 0.6731429879739427, 'per_category_iou': array([4.45592658e-01, 7.27905738e-01, 3.53301085e-05, 1.04372074e-02,\n",
342
+ " 7.95093029e-03, nan, 1.38316668e-06, 0.00000000e+00,\n",
343
+ " 0.00000000e+00, 4.80050947e-01, 0.00000000e+00, 0.00000000e+00,\n",
344
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
345
+ " 0.00000000e+00, 4.95525330e-01, 0.00000000e+00, 1.11859236e-05,\n",
346
+ " 8.15231791e-07, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
347
+ " 2.30751037e-06, 0.00000000e+00, 0.00000000e+00, 6.24028521e-01,\n",
348
+ " 1.05303497e-01, 7.82230424e-01, 0.00000000e+00, 0.00000000e+00,\n",
349
+ " 9.05398708e-04, 0.00000000e+00, nan]), 'per_category_accuracy': array([7.33079709e-01, 9.03717795e-01, 3.57307878e-05, 1.05455168e-02,\n",
350
+ " 8.21302511e-03, nan, 1.38319984e-06, 0.00000000e+00,\n",
351
+ " 0.00000000e+00, 8.08957376e-01, 0.00000000e+00, 0.00000000e+00,\n",
352
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
353
+ " 0.00000000e+00, 8.13886465e-01, 0.00000000e+00, 1.11898068e-05,\n",
354
+ " 8.19780315e-07, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
355
+ " 2.30767678e-06, 0.00000000e+00, 0.00000000e+00, 9.49630788e-01,\n",
356
+ " 1.14320143e-01, 8.41172201e-01, 0.00000000e+00, 0.00000000e+00,\n",
357
+ " 9.22565695e-04, 0.00000000e+00, nan])}\n"
358
+ ]
359
+ },
360
+ {
361
+ "data": {
362
+ "application/vnd.jupyter.widget-view+json": {
363
+ "model_id": "f3ef333a6137481180b4ec59efd1673e",
364
+ "version_major": 2,
365
+ "version_minor": 0
366
+ },
367
+ "text/plain": [
368
+ " 0%| | 0/25 [00:00<?, ?it/s]"
369
+ ]
370
+ },
371
+ "metadata": {},
372
+ "output_type": "display_data"
373
+ },
374
+ {
375
+ "name": "stdout",
376
+ "output_type": "stream",
377
+ "text": [
378
+ "Epoch 3/200 Batch 0/25 Loss 1.7086 Metrics {'mean_iou': 0.13337084618347653, 'mean_accuracy': 0.17885172041247846, 'overall_accuracy': 0.7166871222915292, 'per_category_iou': array([0.49122674, 0.77341676, 0. , 0.1407711 , 0.00553198,\n",
379
+ " nan, 0. , 0. , 0. , 0.55152752,\n",
380
+ " 0. , 0. , 0. , 0. , 0. ,\n",
381
+ " 0. , 0. , 0.51584332, 0. , 0. ,\n",
382
+ " 0. , 0. , 0. , 0. , 0. ,\n",
383
+ " 0. , 0. , 0.68907817, 0.42107346, 0.81276888,\n",
384
+ " 0. , 0. , 0. , 0. , nan]), 'per_category_accuracy': array([0.82412545, 0.91420412, 0. , 0.14242155, 0.00561311,\n",
385
+ " nan, 0. , 0. , 0. , 0.84759725,\n",
386
+ " 0. , 0. , 0. , 0. , 0. ,\n",
387
+ " 0. , 0. , 0.85361568, 0. , 0. ,\n",
388
+ " 0. , 0. , 0. , 0. , 0. ,\n",
389
+ " 0. , 0. , 0.93567573, 0.47723124, 0.90162263,\n",
390
+ " 0. , 0. , 0. , 0. , nan])}\n"
391
+ ]
392
+ },
393
+ {
394
+ "data": {
395
+ "application/vnd.jupyter.widget-view+json": {
396
+ "model_id": "384340072f794e70bcbbc74dacc50c21",
397
+ "version_major": 2,
398
+ "version_minor": 0
399
+ },
400
+ "text/plain": [
401
+ " 0%| | 0/25 [00:00<?, ?it/s]"
402
+ ]
403
+ },
404
+ "metadata": {},
405
+ "output_type": "display_data"
406
+ },
407
+ {
408
+ "name": "stdout",
409
+ "output_type": "stream",
410
+ "text": [
411
+ "Epoch 4/200 Batch 0/25 Loss 1.4155 Metrics {'mean_iou': 0.15564168770290954, 'mean_accuracy': 0.20019284726893685, 'overall_accuracy': 0.7574203051314757, 'per_category_iou': array([5.71505637e-01, 8.01399129e-01, 0.00000000e+00, 4.92403441e-01,\n",
412
+ " 5.41554099e-03, nan, 2.31577543e-07, 0.00000000e+00,\n",
413
+ " 0.00000000e+00, 5.86567051e-01, 0.00000000e+00, 0.00000000e+00,\n",
414
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
415
+ " 0.00000000e+00, 5.51951880e-01, 0.00000000e+00, 0.00000000e+00,\n",
416
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
417
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 7.24133440e-01,\n",
418
+ " 5.67064833e-01, 8.35733013e-01, 0.00000000e+00, 0.00000000e+00,\n",
419
+ " 1.49737193e-06, 0.00000000e+00, nan]), 'per_category_accuracy': array([8.47391615e-01, 9.31621860e-01, 0.00000000e+00, 5.38104498e-01,\n",
420
+ " 5.46801709e-03, nan, 2.31577972e-07, 0.00000000e+00,\n",
421
+ " 0.00000000e+00, 8.70377721e-01, 0.00000000e+00, 0.00000000e+00,\n",
422
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
423
+ " 0.00000000e+00, 8.79001612e-01, 0.00000000e+00, 0.00000000e+00,\n",
424
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
425
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 9.26588036e-01,\n",
426
+ " 6.90579673e-01, 9.17229200e-01, 0.00000000e+00, 0.00000000e+00,\n",
427
+ " 1.49737361e-06, 0.00000000e+00, nan])}\n"
428
+ ]
429
+ },
430
+ {
431
+ "data": {
432
+ "application/vnd.jupyter.widget-view+json": {
433
+ "model_id": "428849c8d6ce414887a0661d3d4625da",
434
+ "version_major": 2,
435
+ "version_minor": 0
436
+ },
437
+ "text/plain": [
438
+ " 0%| | 0/25 [00:00<?, ?it/s]"
439
+ ]
440
+ },
441
+ "metadata": {},
442
+ "output_type": "display_data"
443
+ },
444
+ {
445
+ "name": "stdout",
446
+ "output_type": "stream",
447
+ "text": [
448
+ "Epoch 5/200 Batch 0/25 Loss 1.2700 Metrics {'mean_iou': 0.16264865782481297, 'mean_accuracy': 0.20757696789528982, 'overall_accuracy': 0.7689239961674764, 'per_category_iou': array([0.58414589, 0.81498541, 0. , 0.57467831, 0.00689438,\n",
449
+ " nan, 0. , 0. , 0. , 0.61111453,\n",
450
+ " 0. , 0. , 0. , 0. , 0. ,\n",
451
+ " 0. , 0. , 0.55374014, 0. , 0. ,\n",
452
+ " 0. , 0. , 0. , 0. , 0. ,\n",
453
+ " 0. , 0. , 0.74374917, 0.6398338 , 0.83826408,\n",
454
+ " 0. , 0. , 0. , 0. , nan]), 'per_category_accuracy': array([0.86273737, 0.93244115, 0. , 0.66435561, 0.00696605,\n",
455
+ " nan, 0. , 0. , 0. , 0.87784737,\n",
456
+ " 0. , 0. , 0. , 0. , 0. ,\n",
457
+ " 0. , 0. , 0.87298997, 0. , 0. ,\n",
458
+ " 0. , 0. , 0. , 0. , 0. ,\n",
459
+ " 0. , 0. , 0.92712481, 0.78796366, 0.91761394,\n",
460
+ " 0. , 0. , 0. , 0. , nan])}\n"
461
+ ]
462
+ },
463
+ {
464
+ "data": {
465
+ "application/vnd.jupyter.widget-view+json": {
466
+ "model_id": "1577eed162ec4162a69493365c950329",
467
+ "version_major": 2,
468
+ "version_minor": 0
469
+ },
470
+ "text/plain": [
471
+ " 0%| | 0/25 [00:00<?, ?it/s]"
472
+ ]
473
+ },
474
+ "metadata": {},
475
+ "output_type": "display_data"
476
+ },
477
+ {
478
+ "name": "stdout",
479
+ "output_type": "stream",
480
+ "text": [
481
+ "Epoch 6/200 Batch 0/25 Loss 1.1401 Metrics {'mean_iou': 0.16748948766093116, 'mean_accuracy': 0.21181445128118748, 'overall_accuracy': 0.7795023598828317, 'per_category_iou': array([6.09680179e-01, 8.31918538e-01, 0.00000000e+00, 6.34236889e-01,\n",
482
+ " 1.24257235e-02, nan, 0.00000000e+00, 0.00000000e+00,\n",
483
+ " 0.00000000e+00, 6.35402026e-01, 0.00000000e+00, 0.00000000e+00,\n",
484
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
485
+ " 0.00000000e+00, 5.64656796e-01, 0.00000000e+00, 3.65611521e-06,\n",
486
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
487
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 7.53253595e-01,\n",
488
+ " 6.35493134e-01, 8.50082556e-01, 0.00000000e+00, 0.00000000e+00,\n",
489
+ " 0.00000000e+00, 0.00000000e+00, nan]), 'per_category_accuracy': array([8.81516256e-01, 9.41665624e-01, 0.00000000e+00, 7.39806944e-01,\n",
490
+ " 1.26589939e-02, nan, 0.00000000e+00, 0.00000000e+00,\n",
491
+ " 0.00000000e+00, 8.90136686e-01, 0.00000000e+00, 0.00000000e+00,\n",
492
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
493
+ " 0.00000000e+00, 8.94297220e-01, 0.00000000e+00, 3.65611521e-06,\n",
494
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
495
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 9.26219754e-01,\n",
496
+ " 7.75692895e-01, 9.27878865e-01, 0.00000000e+00, 0.00000000e+00,\n",
497
+ " 0.00000000e+00, 0.00000000e+00, nan])}\n"
498
+ ]
499
+ },
500
+ {
501
+ "data": {
502
+ "application/vnd.jupyter.widget-view+json": {
503
+ "model_id": "169ae321c5ac4a4aab54a816c05035f2",
504
+ "version_major": 2,
505
+ "version_minor": 0
506
+ },
507
+ "text/plain": [
508
+ " 0%| | 0/25 [00:00<?, ?it/s]"
509
+ ]
510
+ },
511
+ "metadata": {},
512
+ "output_type": "display_data"
513
+ },
514
+ {
515
+ "name": "stdout",
516
+ "output_type": "stream",
517
+ "text": [
518
+ "Epoch 7/200 Batch 0/25 Loss 1.0940 Metrics {'mean_iou': 0.17137856945919164, 'mean_accuracy': 0.21495584433690398, 'overall_accuracy': 0.7881396456781556, 'per_category_iou': array([6.32120417e-01, 8.39814384e-01, 0.00000000e+00, 6.53276797e-01,\n",
519
+ " 2.36145329e-02, nan, 0.00000000e+00, 0.00000000e+00,\n",
520
+ " 0.00000000e+00, 6.44513271e-01, 0.00000000e+00, 0.00000000e+00,\n",
521
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
522
+ " 0.00000000e+00, 5.70917228e-01, 0.00000000e+00, 5.49670921e-05,\n",
523
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
524
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 7.63117039e-01,\n",
525
+ " 6.70860701e-01, 8.57197972e-01, 0.00000000e+00, 0.00000000e+00,\n",
526
+ " 5.48363063e-06, 0.00000000e+00, nan]), 'per_category_accuracy': array([9.08405231e-01, 9.45289137e-01, 0.00000000e+00, 7.43501061e-01,\n",
527
+ " 2.43273947e-02, nan, 0.00000000e+00, 0.00000000e+00,\n",
528
+ " 0.00000000e+00, 8.96271845e-01, 0.00000000e+00, 0.00000000e+00,\n",
529
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
530
+ " 0.00000000e+00, 8.93162955e-01, 0.00000000e+00, 5.49677426e-05,\n",
531
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
532
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 9.25997998e-01,\n",
533
+ " 8.26946192e-01, 9.29580599e-01, 0.00000000e+00, 0.00000000e+00,\n",
534
+ " 5.48370681e-06, 0.00000000e+00, nan])}\n"
535
+ ]
536
+ },
537
+ {
538
+ "data": {
539
+ "application/vnd.jupyter.widget-view+json": {
540
+ "model_id": "544f85de17a04a3ca74001e2fbbc6bc9",
541
+ "version_major": 2,
542
+ "version_minor": 0
543
+ },
544
+ "text/plain": [
545
+ " 0%| | 0/25 [00:00<?, ?it/s]"
546
+ ]
547
+ },
548
+ "metadata": {},
549
+ "output_type": "display_data"
550
+ },
551
+ {
552
+ "name": "stdout",
553
+ "output_type": "stream",
554
+ "text": [
555
+ "Epoch 8/200 Batch 0/25 Loss 0.9463 Metrics {'mean_iou': 0.17478409568512296, 'mean_accuracy': 0.21773480038212398, 'overall_accuracy': 0.7927706422623841, 'per_category_iou': array([6.29223165e-01, 8.42392089e-01, 3.52355169e-04, 6.80396607e-01,\n",
556
+ " 5.21568283e-02, nan, 0.00000000e+00, 0.00000000e+00,\n",
557
+ " 0.00000000e+00, 6.55000862e-01, 0.00000000e+00, 0.00000000e+00,\n",
558
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
559
+ " 0.00000000e+00, 5.75092138e-01, 0.00000000e+00, 7.12203670e-04,\n",
560
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
561
+ " 5.25684195e-05, 0.00000000e+00, 0.00000000e+00, 7.79574322e-01,\n",
562
+ " 6.87965551e-01, 8.64953847e-01, 0.00000000e+00, 0.00000000e+00,\n",
563
+ " 2.62043422e-06, 0.00000000e+00, nan]), 'per_category_accuracy': array([9.02830276e-01, 9.46588697e-01, 3.52355621e-04, 7.80701400e-01,\n",
564
+ " 5.56946042e-02, nan, 0.00000000e+00, 0.00000000e+00,\n",
565
+ " 0.00000000e+00, 8.98895404e-01, 0.00000000e+00, 0.00000000e+00,\n",
566
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
567
+ " 0.00000000e+00, 8.98825852e-01, 0.00000000e+00, 7.12358991e-04,\n",
568
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
569
+ " 5.25687464e-05, 0.00000000e+00, 0.00000000e+00, 9.33269035e-01,\n",
570
+ " 8.30703600e-01, 9.36619641e-01, 0.00000000e+00, 0.00000000e+00,\n",
571
+ " 2.62043913e-06, 0.00000000e+00, nan])}\n"
572
+ ]
573
+ },
574
+ {
575
+ "data": {
576
+ "application/vnd.jupyter.widget-view+json": {
577
+ "model_id": "a88b7f1699074168b624caaad5466071",
578
+ "version_major": 2,
579
+ "version_minor": 0
580
+ },
581
+ "text/plain": [
582
+ " 0%| | 0/25 [00:00<?, ?it/s]"
583
+ ]
584
+ },
585
+ "metadata": {},
586
+ "output_type": "display_data"
587
+ },
588
+ {
589
+ "name": "stdout",
590
+ "output_type": "stream",
591
+ "text": [
592
+ "Epoch 9/200 Batch 0/25 Loss 1.1287 Metrics {'mean_iou': 0.17933414157617142, 'mean_accuracy': 0.22211677037195932, 'overall_accuracy': 0.7957460527936768, 'per_category_iou': array([6.42468748e-01, 8.51322031e-01, 4.38690416e-02, 6.83716408e-01,\n",
593
+ " 1.05816211e-01, nan, 3.89710724e-04, 0.00000000e+00,\n",
594
+ " 0.00000000e+00, 6.68283305e-01, 0.00000000e+00, 0.00000000e+00,\n",
595
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
596
+ " 0.00000000e+00, 5.86580529e-01, 0.00000000e+00, 1.00559462e-03,\n",
597
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
598
+ " 1.15812990e-04, 0.00000000e+00, 0.00000000e+00, 7.77224737e-01,\n",
599
+ " 6.87868711e-01, 8.69336423e-01, 0.00000000e+00, 0.00000000e+00,\n",
600
+ " 2.94092022e-05, 0.00000000e+00, nan]), 'per_category_accuracy': array([9.05123309e-01, 9.50701912e-01, 4.38913906e-02, 7.96283141e-01,\n",
601
+ " 1.19598233e-01, nan, 3.89758329e-04, 0.00000000e+00,\n",
602
+ " 0.00000000e+00, 8.97971821e-01, 0.00000000e+00, 0.00000000e+00,\n",
603
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
604
+ " 0.00000000e+00, 9.09473869e-01, 0.00000000e+00, 1.00721614e-03,\n",
605
+ " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
606
+ " 1.15831636e-04, 0.00000000e+00, 0.00000000e+00, 9.30939519e-01,\n",
607
+ " 8.38613121e-01, 9.35714875e-01, 0.00000000e+00, 0.00000000e+00,\n",
608
+ " 2.94260342e-05, 0.00000000e+00, nan])}\n"
609
+ ]
610
+ },
611
+ {
612
+ "data": {
613
+ "application/vnd.jupyter.widget-view+json": {
614
+ "model_id": "569b4a3ebdd24be7811709fc8d588e30",
615
+ "version_major": 2,
616
+ "version_minor": 0
617
+ },
618
+ "text/plain": [
619
+ " 0%| | 0/25 [00:00<?, ?it/s]"
620
+ ]
621
+ },
622
+ "metadata": {},
623
+ "output_type": "display_data"
624
+ },
625
+ {
626
+ "ename": "KeyboardInterrupt",
627
+ "evalue": "",
628
+ "output_type": "error",
629
+ "traceback": [
630
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
631
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
632
+ "\u001b[1;32m/home/chainyo/code/segformer-sidewalk/finetuning.ipynb Cell 10'\u001b[0m in \u001b[0;36m<cell line: 7>\u001b[0;34m()\u001b[0m\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000022?line=24'>25</a>\u001b[0m metric\u001b[39m.\u001b[39madd_batch(predictions\u001b[39m=\u001b[39mpredicted\u001b[39m.\u001b[39mdetach()\u001b[39m.\u001b[39mcpu()\u001b[39m.\u001b[39mnumpy(), references\u001b[39m=\u001b[39mlabels\u001b[39m.\u001b[39mdetach()\u001b[39m.\u001b[39mcpu()\u001b[39m.\u001b[39mnumpy())\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000022?line=26'>27</a>\u001b[0m \u001b[39mif\u001b[39;00m index \u001b[39m%\u001b[39m \u001b[39m100\u001b[39m \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[0;32m---> <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000022?line=27'>28</a>\u001b[0m metrics \u001b[39m=\u001b[39m metric\u001b[39m.\u001b[39;49mcompute(num_labels\u001b[39m=\u001b[39;49mnum_labels, ignore_index\u001b[39m=\u001b[39;49m\u001b[39m255\u001b[39;49m, reduce_labels\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m)\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000022?line=28'>29</a>\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mEpoch \u001b[39m\u001b[39m{\u001b[39;00mepoch\u001b[39m}\u001b[39;00m\u001b[39m/\u001b[39m\u001b[39m{\u001b[39;00mEPOCHS\u001b[39m}\u001b[39;00m\u001b[39m Batch \u001b[39m\u001b[39m{\u001b[39;00mindex\u001b[39m}\u001b[39;00m\u001b[39m/\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mlen\u001b[39m(train_dataloader)\u001b[39m}\u001b[39;00m\u001b[39m Loss \u001b[39m\u001b[39m{\u001b[39;00mloss\u001b[39m.\u001b[39mitem()\u001b[39m:\u001b[39;00m\u001b[39m.4f\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m Metrics 
\u001b[39m\u001b[39m{\u001b[39;00mmetrics\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n",
633
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py:428\u001b[0m, in \u001b[0;36mMetric.compute\u001b[0;34m(self, predictions, references, **kwargs)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=424'>425</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mprocess_id \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=425'>426</a>\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdata\u001b[39m.\u001b[39mset_format(\u001b[39mtype\u001b[39m\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39minfo\u001b[39m.\u001b[39mformat)\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=427'>428</a>\u001b[0m inputs \u001b[39m=\u001b[39m {input_name: \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdata[input_name] \u001b[39mfor\u001b[39;00m input_name \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfeatures}\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=428'>429</a>\u001b[0m \u001b[39mwith\u001b[39;00m temp_seed(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mseed):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=429'>430</a>\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compute(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39minputs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mcompute_kwargs)\n",
634
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py:428\u001b[0m, in \u001b[0;36m<dictcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=424'>425</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mprocess_id \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=425'>426</a>\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdata\u001b[39m.\u001b[39mset_format(\u001b[39mtype\u001b[39m\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39minfo\u001b[39m.\u001b[39mformat)\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=427'>428</a>\u001b[0m inputs \u001b[39m=\u001b[39m {input_name: \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdata[input_name] \u001b[39mfor\u001b[39;00m input_name \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfeatures}\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=428'>429</a>\u001b[0m \u001b[39mwith\u001b[39;00m temp_seed(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mseed):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=429'>430</a>\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compute(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39minputs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mcompute_kwargs)\n",
635
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py:2124\u001b[0m, in \u001b[0;36mDataset.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2121'>2122</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__getitem__\u001b[39m(\u001b[39mself\u001b[39m, key): \u001b[39m# noqa: F811\u001b[39;00m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2122'>2123</a>\u001b[0m \u001b[39m\"\"\"Can be used to index columns (by string names) or rows (by integer index or iterable of indices or bools).\"\"\"\u001b[39;00m\n\u001b[0;32m-> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2123'>2124</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_getitem(\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2124'>2125</a>\u001b[0m key,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2125'>2126</a>\u001b[0m )\n",
636
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py:2109\u001b[0m, in \u001b[0;36mDataset._getitem\u001b[0;34m(self, key, decoded, **kwargs)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2106'>2107</a>\u001b[0m formatter \u001b[39m=\u001b[39m get_formatter(format_type, features\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfeatures, decoded\u001b[39m=\u001b[39mdecoded, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mformat_kwargs)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2107'>2108</a>\u001b[0m pa_subtable \u001b[39m=\u001b[39m query_table(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_data, key, indices\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_indices \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_indices \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m)\n\u001b[0;32m-> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2108'>2109</a>\u001b[0m formatted_output \u001b[39m=\u001b[39m format_table(\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2109'>2110</a>\u001b[0m pa_subtable, key, formatter\u001b[39m=\u001b[39;49mformatter, format_columns\u001b[39m=\u001b[39;49mformat_columns, output_all_columns\u001b[39m=\u001b[39;49moutput_all_columns\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2110'>2111</a>\u001b[0m )\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2111'>2112</a>\u001b[0m 
\u001b[39mreturn\u001b[39;00m formatted_output\n",
637
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py:532\u001b[0m, in \u001b[0;36mformat_table\u001b[0;34m(table, key, formatter, format_columns, output_all_columns)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=529'>530</a>\u001b[0m python_formatter \u001b[39m=\u001b[39m PythonFormatter(features\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=530'>531</a>\u001b[0m \u001b[39mif\u001b[39;00m format_columns \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=531'>532</a>\u001b[0m \u001b[39mreturn\u001b[39;00m formatter(pa_table, query_type\u001b[39m=\u001b[39;49mquery_type)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=532'>533</a>\u001b[0m \u001b[39melif\u001b[39;00m query_type \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mcolumn\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=533'>534</a>\u001b[0m \u001b[39mif\u001b[39;00m key \u001b[39min\u001b[39;00m format_columns:\n",
638
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py:283\u001b[0m, in \u001b[0;36mFormatter.__call__\u001b[0;34m(self, pa_table, query_type)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=280'>281</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mformat_row(pa_table)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=281'>282</a>\u001b[0m \u001b[39melif\u001b[39;00m query_type \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mcolumn\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=282'>283</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mformat_column(pa_table)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=283'>284</a>\u001b[0m \u001b[39melif\u001b[39;00m query_type \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mbatch\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=284'>285</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mformat_batch(pa_table)\n",
639
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py:316\u001b[0m, in \u001b[0;36mPythonFormatter.format_column\u001b[0;34m(self, pa_table)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=314'>315</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mformat_column\u001b[39m(\u001b[39mself\u001b[39m, pa_table: pa\u001b[39m.\u001b[39mTable) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mlist\u001b[39m:\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=315'>316</a>\u001b[0m column \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mpython_arrow_extractor()\u001b[39m.\u001b[39;49mextract_column(pa_table)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=316'>317</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdecoded:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=317'>318</a>\u001b[0m column \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpython_features_decoder\u001b[39m.\u001b[39mdecode_column(column, pa_table\u001b[39m.\u001b[39mcolumn_names[\u001b[39m0\u001b[39m])\n",
640
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py:143\u001b[0m, in \u001b[0;36mPythonArrowExtractor.extract_column\u001b[0;34m(self, pa_table)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=141'>142</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mextract_column\u001b[39m(\u001b[39mself\u001b[39m, pa_table: pa\u001b[39m.\u001b[39mTable) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mlist\u001b[39m:\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=142'>143</a>\u001b[0m \u001b[39mreturn\u001b[39;00m pa_table\u001b[39m.\u001b[39;49mcolumn(\u001b[39m0\u001b[39;49m)\u001b[39m.\u001b[39;49mto_pylist()\n",
641
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
642
+ ]
643
+ }
644
+ ],
645
+ "source": [
646
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006)\n",
647
+ "\n",
648
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
649
+ "model.to(device)\n",
650
+ "\n",
651
+ "model.train()\n",
652
+ "for epoch in range(EPOCHS):\n",
653
+ " for index, batch in enumerate(tqdm(train_dataloader)):\n",
654
+ " pixel_values = batch[\"pixel_values\"].to(device)\n",
655
+ " labels = batch[\"labels\"].to(device)\n",
656
+ "\n",
657
+ " optimizer.zero_grad()\n",
658
+ "\n",
659
+ " outputs = model(pixel_values=pixel_values, labels=labels)\n",
660
+ " loss, logits = outputs.loss, outputs.logits\n",
661
+ "\n",
662
+ " loss.backward()\n",
663
+ " optimizer.step()\n",
664
+ "\n",
665
+ " with torch.no_grad():\n",
666
+ " upsampled_logits = nn.functional.interpolate(\n",
667
+ " logits, size=labels.shape[-2:], mode=\"bilinear\", align_corners=False\n",
668
+ " )\n",
669
+ " predicted = upsampled_logits.argmax(dim=1)\n",
670
+ " metric.add_batch(predictions=predicted.detach().cpu().numpy(), references=labels.detach().cpu().numpy())\n",
671
+ "\n",
672
+ " if index % 100 == 0:\n",
673
+ " metrics = metric.compute(num_labels=num_labels, ignore_index=255, reduce_labels=False)\n",
674
+ " print(f\"Epoch {epoch}/{EPOCHS} Batch {index}/{len(train_dataloader)} Loss {loss.item():.4f} Metrics {metrics}\")"
675
+ ]
676
+ },
677
+ {
678
+ "cell_type": "code",
679
+ "execution_count": 5,
680
+ "metadata": {},
681
+ "outputs": [
682
+ {
683
+ "name": "stderr",
684
+ "output_type": "stream",
685
+ "text": [
686
+ "Using custom data configuration segments--sidewalk-semantic-2-f89d0845be9cadc9\n",
687
+ "Reusing dataset parquet (/home/chainyo/.cache/huggingface/datasets/segments___parquet/segments--sidewalk-semantic-2-f89d0845be9cadc9/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)\n"
688
+ ]
689
+ },
690
+ {
691
+ "name": "stdout",
692
+ "output_type": "stream",
693
+ "text": [
694
+ "1000\n"
695
+ ]
696
+ }
697
+ ],
698
+ "source": [
699
+ "import numpy as np\n",
700
+ "\n",
701
+ "tokenizer = SegformerFeatureExtractor(reduce_labels=True)\n",
702
+ "\n",
703
+ "dataset = load_dataset(HUB_DIR, split=\"train\")\n",
704
+ "length = len(dataset)\n",
705
+ "print(length)\n",
706
+ "\n",
707
+ "encoded_dataset = tokenizer(\n",
708
+ " images=dataset[\"pixel_values\"], segmentation_maps=dataset[\"label\"], return_tensors=\"pt\"\n",
709
+ ")"
710
+ ]
711
+ },
712
+ {
713
+ "cell_type": "code",
714
+ "execution_count": 9,
715
+ "metadata": {},
716
+ "outputs": [],
717
+ "source": [
718
+ "pixel_values = encoded_dataset[\"pixel_values\"]\n",
719
+ "labels = encoded_dataset[\"labels\"]"
720
+ ]
721
+ },
722
+ {
723
+ "cell_type": "code",
724
+ "execution_count": 10,
725
+ "metadata": {},
726
+ "outputs": [
727
+ {
728
+ "data": {
729
+ "text/plain": [
730
+ "torch.Tensor"
731
+ ]
732
+ },
733
+ "execution_count": 10,
734
+ "metadata": {},
735
+ "output_type": "execute_result"
736
+ }
737
+ ],
738
+ "source": [
739
+ "type(pixel_values)"
740
+ ]
741
+ },
742
+ {
743
+ "cell_type": "code",
744
+ "execution_count": 25,
745
+ "metadata": {},
746
+ "outputs": [],
747
+ "source": [
748
+ "from torch.utils.data import DataLoader, Dataset, random_split, Subset\n",
749
+ "\n",
750
+ "class SegmentationDataset(Dataset):\n",
751
+ " def __init__(self, pixel_values: torch.Tensor, labels: torch.Tensor):\n",
752
+ " self.pixel_values = pixel_values\n",
753
+ " self.labels = labels\n",
754
+ " assert pixel_values.shape[0] == labels.shape[0]\n",
755
+ " self.length = pixel_values.shape[0]\n",
756
+ " print(f\"Created dataset with {self.length} samples\")\n",
757
+ " \n",
758
+ "\n",
759
+ " def __len__(self):\n",
760
+ " return self.length\n",
761
+ "\n",
762
+ "\n",
763
+ " def __getitem__(self, index):\n",
764
+ " image = self.pixel_values[index]\n",
765
+ " label = self.labels[index]\n",
766
+ "\n",
767
+ " encoded_inputs = BatchFeature({\"pixel_values\": image, \"labels\": label})\n",
768
+ "\n",
769
+ " return encoded_inputs"
770
+ ]
771
+ },
772
+ {
773
+ "cell_type": "code",
774
+ "execution_count": 26,
775
+ "metadata": {},
776
+ "outputs": [
777
+ {
778
+ "name": "stdout",
779
+ "output_type": "stream",
780
+ "text": [
781
+ "Created dataset with 1000 samples\n"
782
+ ]
783
+ }
784
+ ],
785
+ "source": [
786
+ "segmentation_dataset = SegmentationDataset(pixel_values, labels)"
787
+ ]
788
+ },
789
+ {
790
+ "cell_type": "code",
791
+ "execution_count": 27,
792
+ "metadata": {},
793
+ "outputs": [],
794
+ "source": [
795
+ "test = segmentation_dataset[0]"
796
+ ]
797
+ },
798
+ {
799
+ "cell_type": "code",
800
+ "execution_count": 28,
801
+ "metadata": {},
802
+ "outputs": [
803
+ {
804
+ "data": {
805
+ "text/plain": [
806
+ "{'pixel_values': tensor([[[ 0.0912, -0.1828, -0.1143, ..., -0.5253, -0.5424, -0.6623],\n",
807
+ " [-0.0116, -0.2342, -0.1486, ..., -0.5424, -0.6109, -0.7137],\n",
808
+ " [ 0.0398, -0.1828, -0.1314, ..., -0.5082, -0.6281, -0.7650],\n",
809
+ " ...,\n",
810
+ " [ 1.1529, 1.3927, 1.0331, ..., 0.4166, 0.3481, 0.3309],\n",
811
+ " [ 0.9474, 1.1358, 1.3070, ..., 0.5022, 0.3652, 0.3994],\n",
812
+ " [ 0.6049, 1.3413, 1.1358, ..., 1.2728, 0.6563, 0.8104]],\n",
813
+ "\n",
814
+ " [[ 0.3102, -0.2150, -0.3200, ..., -0.3901, -0.4426, -0.5651],\n",
815
+ " [ 0.2227, -0.2850, -0.3725, ..., -0.4076, -0.5126, -0.6176],\n",
816
+ " [ 0.2577, -0.2325, -0.3550, ..., -0.3725, -0.5301, -0.6702],\n",
817
+ " ...,\n",
818
+ " [ 1.1506, 1.3957, 1.0280, ..., 0.4678, 0.3978, 0.3803],\n",
819
+ " [ 0.9405, 1.1331, 1.3081, ..., 0.5553, 0.4153, 0.4503],\n",
820
+ " [ 0.5903, 1.3431, 1.1331, ..., 1.3431, 0.7129, 0.8704]],\n",
821
+ "\n",
822
+ " [[ 0.4788, -0.0267, -0.1312, ..., 0.0431, -0.0441, -0.2010],\n",
823
+ " [ 0.3916, -0.0790, -0.1835, ..., 0.0256, -0.1138, -0.2532],\n",
824
+ " [ 0.4265, -0.0267, -0.1661, ..., 0.0605, -0.1312, -0.3055],\n",
825
+ " ...,\n",
826
+ " [ 1.2805, 1.5245, 1.1585, ..., 0.6356, 0.5659, 0.5485],\n",
827
+ " [ 1.0714, 1.2631, 1.4374, ..., 0.7228, 0.5834, 0.6182],\n",
828
+ " [ 0.7228, 1.4722, 1.2631, ..., 1.5071, 0.8797, 1.0365]]]), 'labels': tensor([[17, 17, 17, ..., 17, 17, 17],\n",
829
+ " [17, 17, 17, ..., 17, 17, 17],\n",
830
+ " [17, 17, 17, ..., 17, 17, 17],\n",
831
+ " ...,\n",
832
+ " [ 1, 1, 1, ..., 1, 1, 1],\n",
833
+ " [ 1, 1, 1, ..., 1, 1, 1],\n",
834
+ " [ 1, 1, 1, ..., 1, 1, 1]])}"
835
+ ]
836
+ },
837
+ "execution_count": 28,
838
+ "metadata": {},
839
+ "output_type": "execute_result"
840
+ }
841
+ ],
842
+ "source": [
843
+ "test"
844
+ ]
845
+ },
846
+ {
847
+ "cell_type": "code",
848
+ "execution_count": 20,
849
+ "metadata": {},
850
+ "outputs": [
851
+ {
852
+ "data": {
853
+ "text/plain": [
854
+ "tensor([ 0, 1, 2, 4, 6, 9, 17, 18, 19, 23, 24, 25, 27, 31, 32])"
855
+ ]
856
+ },
857
+ "execution_count": 20,
858
+ "metadata": {},
859
+ "output_type": "execute_result"
860
+ }
861
+ ],
862
+ "source": [
863
+ "segmentation_dataset[0][\"labels\"].squeeze().unique()"
864
+ ]
865
+ },
866
+ {
867
+ "cell_type": "code",
868
+ "execution_count": 22,
869
+ "metadata": {},
870
+ "outputs": [],
871
+ "source": [
872
+ "# Split the wrapped dataset itself so both subsets keep yielding {pixel_values, labels} items.\n",
873
+ "train_size = int(length * 0.8)\n",
874
+ "lengths = [train_size, length - train_size]  # remainder goes to validation so the sizes always sum to length\n",
875
+ "\n",
876
+ "train_dataset, val_dataset = random_split(segmentation_dataset, lengths)"
877
+ ]
878
+ },
879
+ {
880
+ "cell_type": "code",
881
+ "execution_count": 26,
882
+ "metadata": {},
883
+ "outputs": [
884
+ {
885
+ "ename": "TypeError",
886
+ "evalue": "list indices must be integers or slices, not str",
887
+ "output_type": "error",
888
+ "traceback": [
889
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
890
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
891
+ "\u001b[1;32m/home/chainyo/code/segformer-sidewalk/finetuning.ipynb Cell 14'\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000018?line=0'>1</a>\u001b[0m train_dataset[\u001b[39m0\u001b[39;49m]\n",
892
+ "\u001b[1;32m/home/chainyo/code/segformer-sidewalk/finetuning.ipynb Cell 12'\u001b[0m in \u001b[0;36mSegmentationDataset.__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=11'>12</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__getitem__\u001b[39m(\u001b[39mself\u001b[39m, index):\n\u001b[0;32m---> <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=12'>13</a>\u001b[0m image \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset[\u001b[39m\"\u001b[39;49m\u001b[39mpixel_values\u001b[39;49m\u001b[39m\"\u001b[39;49m][index]\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=13'>14</a>\u001b[0m label \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[\u001b[39m\"\u001b[39m\u001b[39mlabel\u001b[39m\u001b[39m\"\u001b[39m][index]\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=15'>16</a>\u001b[0m \u001b[39mreturn\u001b[39;00m image, label\n",
893
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py:471\u001b[0m, in \u001b[0;36mSubset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=468'>469</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(idx, \u001b[39mlist\u001b[39m):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=469'>470</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mindices[i] \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m idx]]\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=470'>471</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mindices[idx]]\n",
894
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py:471\u001b[0m, in \u001b[0;36mSubset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=468'>469</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(idx, \u001b[39mlist\u001b[39m):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=469'>470</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mindices[i] \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m idx]]\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=470'>471</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mindices[idx]]\n",
895
+ "\u001b[0;31mTypeError\u001b[0m: list indices must be integers or slices, not str"
896
+ ]
897
+ }
898
+ ],
899
+ "source": [
900
+ "train_dataset[0]"
901
+ ]
902
+ },
903
+ {
904
+ "cell_type": "code",
905
+ "execution_count": 23,
906
+ "metadata": {},
907
+ "outputs": [],
908
+ "source": [
909
+ "train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)\n",
910
+ "valid_dataloader = DataLoader(val_dataset, batch_size=2)"
911
+ ]
912
+ },
913
+ {
914
+ "cell_type": "code",
915
+ "execution_count": 24,
916
+ "metadata": {},
917
+ "outputs": [
918
+ {
919
+ "ename": "TypeError",
920
+ "evalue": "list indices must be integers or slices, not str",
921
+ "output_type": "error",
922
+ "traceback": [
923
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
924
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
925
+ "\u001b[1;32m/home/chainyo/code/segformer-sidewalk/finetuning.ipynb Cell 15'\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=0'>1</a>\u001b[0m batch \u001b[39m=\u001b[39m \u001b[39mnext\u001b[39;49m(\u001b[39miter\u001b[39;49m(train_dataloader))\n",
926
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py:530\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=527'>528</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sampler_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=528'>529</a>\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reset()\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=529'>530</a>\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_next_data()\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=530'>531</a>\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=531'>532</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=532'>533</a>\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m <a 
href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=533'>534</a>\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called:\n",
927
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py:570\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=567'>568</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_next_data\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=568'>569</a>\u001b[0m index \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_next_index() \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=569'>570</a>\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_dataset_fetcher\u001b[39m.\u001b[39;49mfetch(index) \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=570'>571</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=571'>572</a>\u001b[0m data \u001b[39m=\u001b[39m _utils\u001b[39m.\u001b[39mpin_memory\u001b[39m.\u001b[39mpin_memory(data)\n",
928
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py:49\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=46'>47</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mfetch\u001b[39m(\u001b[39mself\u001b[39m, possibly_batched_index):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=47'>48</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mauto_collation:\n\u001b[0;32m---> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=48'>49</a>\u001b[0m data \u001b[39m=\u001b[39m [\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[idx] \u001b[39mfor\u001b[39;00m idx \u001b[39min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=49'>50</a>\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=50'>51</a>\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[possibly_batched_index]\n",
929
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py:49\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=46'>47</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mfetch\u001b[39m(\u001b[39mself\u001b[39m, possibly_batched_index):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=47'>48</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mauto_collation:\n\u001b[0;32m---> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=48'>49</a>\u001b[0m data \u001b[39m=\u001b[39m [\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset[idx] \u001b[39mfor\u001b[39;00m idx \u001b[39min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=49'>50</a>\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=50'>51</a>\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[possibly_batched_index]\n",
930
+ "\u001b[1;32m/home/chainyo/code/segformer-sidewalk/finetuning.ipynb Cell 12'\u001b[0m in \u001b[0;36mSegmentationDataset.__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=11'>12</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__getitem__\u001b[39m(\u001b[39mself\u001b[39m, index):\n\u001b[0;32m---> <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=12'>13</a>\u001b[0m image \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset[\u001b[39m\"\u001b[39;49m\u001b[39mpixel_values\u001b[39;49m\u001b[39m\"\u001b[39;49m][index]\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=13'>14</a>\u001b[0m label \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[\u001b[39m\"\u001b[39m\u001b[39mlabel\u001b[39m\u001b[39m\"\u001b[39m][index]\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=15'>16</a>\u001b[0m \u001b[39mreturn\u001b[39;00m image, label\n",
931
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py:471\u001b[0m, in \u001b[0;36mSubset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=468'>469</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(idx, \u001b[39mlist\u001b[39m):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=469'>470</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mindices[i] \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m idx]]\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=470'>471</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mindices[idx]]\n",
932
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py:471\u001b[0m, in \u001b[0;36mSubset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=468'>469</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(idx, \u001b[39mlist\u001b[39m):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=469'>470</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mindices[i] \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m idx]]\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=470'>471</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mindices[idx]]\n",
933
+ "\u001b[0;31mTypeError\u001b[0m: list indices must be integers or slices, not str"
934
+ ]
935
+ }
936
+ ],
937
+ "source": [
938
+ "batch = next(iter(train_dataloader))"
939
+ ]
940
+ },
941
+ {
942
+ "cell_type": "code",
943
+ "execution_count": null,
944
+ "metadata": {},
945
+ "outputs": [],
946
+ "source": []
947
+ }
948
+ ],
949
+ "metadata": {
950
+ "interpreter": {
951
+ "hash": "19a4294f02aad727716e8ed0f765e04171ea0ecb4c129d0b0eebf53be4c3a095"
952
+ },
953
+ "kernelspec": {
954
+ "display_name": "Python 3.8.13 ('segformer')",
955
+ "language": "python",
956
+ "name": "python3"
957
+ },
958
+ "language_info": {
959
+ "codemirror_mode": {
960
+ "name": "ipython",
961
+ "version": 3
962
+ },
963
+ "file_extension": ".py",
964
+ "mimetype": "text/x-python",
965
+ "name": "python",
966
+ "nbconvert_exporter": "python",
967
+ "pygments_lexer": "ipython3",
968
+ "version": "3.8.13"
969
+ },
970
+ "orig_nbformat": 4
971
+ },
972
+ "nbformat": 4,
973
+ "nbformat_minor": 2
974
+ }