test notebook
Browse files- finetuning.ipynb +974 -0
finetuning.ipynb
ADDED
@@ -0,0 +1,974 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import json\n",
|
10 |
+
"import pytorch_lightning as pl\n",
|
11 |
+
"import torch\n",
|
12 |
+
"import torchmetrics\n",
|
13 |
+
"\n",
|
14 |
+
"from datasets import load_dataset, load_metric\n",
|
15 |
+
"\n",
|
16 |
+
"from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation\n",
|
17 |
+
"\n",
|
18 |
+
"from torch import nn\n",
|
19 |
+
"from torch.utils.data import DataLoader, Dataset, random_split\n",
|
20 |
+
"\n",
|
21 |
+
"from tqdm.notebook import tqdm"
|
22 |
+
]
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"cell_type": "code",
|
26 |
+
"execution_count": 4,
|
27 |
+
"metadata": {},
|
28 |
+
"outputs": [],
|
29 |
+
"source": [
|
30 |
+
"class SemanticSegmentationDataset(Dataset):\n",
|
31 |
+
" \"\"\"Image segmentation datasets.\"\"\"\n",
|
32 |
+
"\n",
|
33 |
+
" def __init__(\n",
|
34 |
+
" self, \n",
|
35 |
+
" dataset: torch.utils.data.dataset.Subset, \n",
|
36 |
+
" feature_extractor = SegformerFeatureExtractor(reduce_labels=True),\n",
|
37 |
+
" ):\n",
|
38 |
+
" \"\"\"\n",
|
39 |
+
" Initialize the dataset with the given feature extractor and split.\n",
|
40 |
+
"\n",
|
41 |
+
" Parameters\n",
|
42 |
+
" ----------\n",
|
43 |
+
" hub_dir : Dataset\n",
|
44 |
+
" The dataset to use.\n",
|
45 |
+
" feature_extractor : FeatureExtractor, optional\n",
|
46 |
+
" The feature extractor to use. The default is SegformerFeatureExtractor.\n",
|
47 |
+
" \"\"\"\n",
|
48 |
+
" self.dataset = dataset\n",
|
49 |
+
" self.feature_extractor = feature_extractor\n",
|
50 |
+
" self.length = len(self.dataset)\n",
|
51 |
+
" print(f\"Loaded {self.length} samples.\")\n",
|
52 |
+
"\n",
|
53 |
+
"\n",
|
54 |
+
" def __len__(self):\n",
|
55 |
+
" \"\"\"Return the number of samples in the dataset.\"\"\"\n",
|
56 |
+
" return self.length\n",
|
57 |
+
"\n",
|
58 |
+
"\n",
|
59 |
+
" def __getitem__(self, index: int):\n",
|
60 |
+
" \"\"\"Get the sample at the given index.\"\"\"\n",
|
61 |
+
" image = self.dataset[index][\"pixel_values\"]\n",
|
62 |
+
" label = self.dataset[index][\"label\"]\n",
|
63 |
+
"\n",
|
64 |
+
" encoded_inputs = self.feature_extractor(image, label, return_tensors=\"pt\")\n",
|
65 |
+
"\n",
|
66 |
+
" for k, v in encoded_inputs.items():\n",
|
67 |
+
" encoded_inputs[k].squeeze_() # remove batch dimension\n",
|
68 |
+
"\n",
|
69 |
+
" return encoded_inputs"
|
70 |
+
]
|
71 |
+
},
|
72 |
+
{
|
73 |
+
"cell_type": "code",
|
74 |
+
"execution_count": 3,
|
75 |
+
"metadata": {},
|
76 |
+
"outputs": [],
|
77 |
+
"source": [
|
78 |
+
"BATCH_SIZE = 32\n",
|
79 |
+
"HUB_DIR = \"segments/sidewalk-semantic\"\n",
|
80 |
+
"EPOCHS = 200"
|
81 |
+
]
|
82 |
+
},
|
83 |
+
{
|
84 |
+
"cell_type": "code",
|
85 |
+
"execution_count": 5,
|
86 |
+
"metadata": {},
|
87 |
+
"outputs": [
|
88 |
+
{
|
89 |
+
"name": "stderr",
|
90 |
+
"output_type": "stream",
|
91 |
+
"text": [
|
92 |
+
"Using custom data configuration segments--sidewalk-semantic-2-f89d0845be9cadc9\n",
|
93 |
+
"Reusing dataset parquet (/home/chainyo/.cache/huggingface/datasets/segments___parquet/segments--sidewalk-semantic-2-f89d0845be9cadc9/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)\n"
|
94 |
+
]
|
95 |
+
},
|
96 |
+
{
|
97 |
+
"name": "stdout",
|
98 |
+
"output_type": "stream",
|
99 |
+
"text": [
|
100 |
+
"Loaded 800 samples.\n",
|
101 |
+
"Loaded 200 samples.\n"
|
102 |
+
]
|
103 |
+
}
|
104 |
+
],
|
105 |
+
"source": [
|
106 |
+
"dataset = load_dataset(HUB_DIR, split=\"train\")\n",
|
107 |
+
"\n",
|
108 |
+
"train_dataset, val_dataset = random_split(dataset, [int(0.8 * len(dataset)), len(dataset) - int(0.8 * len(dataset))])\n",
|
109 |
+
"train_dataset = SemanticSegmentationDataset(train_dataset)\n",
|
110 |
+
"val_dataset = SemanticSegmentationDataset(val_dataset)"
|
111 |
+
]
|
112 |
+
},
|
113 |
+
{
|
114 |
+
"cell_type": "code",
|
115 |
+
"execution_count": 5,
|
116 |
+
"metadata": {},
|
117 |
+
"outputs": [],
|
118 |
+
"source": [
|
119 |
+
"train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
|
120 |
+
"val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)"
|
121 |
+
]
|
122 |
+
},
|
123 |
+
{
|
124 |
+
"cell_type": "code",
|
125 |
+
"execution_count": 6,
|
126 |
+
"metadata": {},
|
127 |
+
"outputs": [
|
128 |
+
{
|
129 |
+
"name": "stdout",
|
130 |
+
"output_type": "stream",
|
131 |
+
"text": [
|
132 |
+
"pixel_values torch.Size([32, 3, 512, 512])\n",
|
133 |
+
"labels torch.Size([32, 512, 512])\n"
|
134 |
+
]
|
135 |
+
}
|
136 |
+
],
|
137 |
+
"source": [
|
138 |
+
"batch = next(iter(train_dataloader))\n",
|
139 |
+
"\n",
|
140 |
+
"for k, v in batch.items():\n",
|
141 |
+
" print(k, v.shape)"
|
142 |
+
]
|
143 |
+
},
|
144 |
+
{
|
145 |
+
"cell_type": "code",
|
146 |
+
"execution_count": 7,
|
147 |
+
"metadata": {},
|
148 |
+
"outputs": [],
|
149 |
+
"source": [
|
150 |
+
"# class SidewalkSegmentationModel(pl.LightningModule):\n",
|
151 |
+
"# def __init__(self, num_classes: int, learning_rate: float = 6e-5):\n",
|
152 |
+
"# super().__init__()\n",
|
153 |
+
"# self.model = SegformerForSemanticSegmentation.from_pretrained(\n",
|
154 |
+
"# \"nvidia/mit-b0\", num_labels=num_classes, id2label=id2label, label2id=label2id,\n",
|
155 |
+
"# )\n",
|
156 |
+
"# self.learning_rate = learning_rate\n",
|
157 |
+
"# self.metric = load_metric(\"mean_iou\")\n",
|
158 |
+
"\n",
|
159 |
+
" \n",
|
160 |
+
"# def forward(self, *args, **kwargs):\n",
|
161 |
+
"# return self.model(*args, **kwargs)\n",
|
162 |
+
"\n",
|
163 |
+
" \n",
|
164 |
+
"# def training_step(self, batch, batch_idx):\n",
|
165 |
+
"# pixel_values = batch[\"pixel_values\"]\n",
|
166 |
+
"# labels = batch[\"labels\"]\n",
|
167 |
+
"\n",
|
168 |
+
"# outputs = self(pixel_values=pixel_values, labels=labels)\n",
|
169 |
+
"# loss, logits = outputs.loss, outputs.logits\n",
|
170 |
+
"\n",
|
171 |
+
" \n",
|
172 |
+
"# def configure_optimizers(self) -> torch.optim.AdamW:\n",
|
173 |
+
"# \"\"\"\n",
|
174 |
+
"# Configure the optimizer.\n",
|
175 |
+
"# Returns\n",
|
176 |
+
"# -------\n",
|
177 |
+
"# torch.optim.AdamW\n",
|
178 |
+
"# Optimizer for the model\n",
|
179 |
+
"# \"\"\"\n",
|
180 |
+
"# return torch.optim.AdamW(model.parameters(), lr=self.learning_rate)"
|
181 |
+
]
|
182 |
+
},
|
183 |
+
{
|
184 |
+
"cell_type": "code",
|
185 |
+
"execution_count": 11,
|
186 |
+
"metadata": {},
|
187 |
+
"outputs": [
|
188 |
+
{
|
189 |
+
"name": "stdout",
|
190 |
+
"output_type": "stream",
|
191 |
+
"text": [
|
192 |
+
"{0: 'unlabeled', 1: 'flat-road', 2: 'flat-sidewalk', 3: 'flat-crosswalk', 4: 'flat-cyclinglane', 5: 'flat-parkingdriveway', 6: 'flat-railtrack', 7: 'flat-curb', 8: 'human-person', 9: 'human-rider', 10: 'vehicle-car', 11: 'vehicle-truck', 12: 'vehicle-bus', 13: 'vehicle-tramtrain', 14: 'vehicle-motorcycle', 15: 'vehicle-bicycle', 16: 'vehicle-caravan', 17: 'vehicle-cartrailer', 18: 'construction-building', 19: 'construction-door', 20: 'construction-wall', 21: 'construction-fenceguardrail', 22: 'construction-bridge', 23: 'construction-tunnel', 24: 'construction-stairs', 25: 'object-pole', 26: 'object-trafficsign', 27: 'object-trafficlight', 28: 'nature-vegetation', 29: 'nature-terrain', 30: 'sky', 31: 'void-ground', 32: 'void-dynamic', 33: 'void-static', 34: 'void-unclear'}\n"
|
193 |
+
]
|
194 |
+
},
|
195 |
+
{
|
196 |
+
"name": "stderr",
|
197 |
+
"output_type": "stream",
|
198 |
+
"text": [
|
199 |
+
"Some weights of the model checkpoint at nvidia/mit-b0 were not used when initializing SegformerForSemanticSegmentation: ['classifier.weight', 'classifier.bias']\n",
|
200 |
+
"- This IS expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
201 |
+
"- This IS NOT expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
202 |
+
"Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.classifier.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.linear_c.1.proj.weight', 'decode_head.classifier.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.batch_norm.bias', 'decode_head.linear_c.2.proj.bias']\n",
|
203 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
204 |
+
]
|
205 |
+
}
|
206 |
+
],
|
207 |
+
"source": [
|
208 |
+
"id2label_file = json.load(open(\"id2label.json\", \"r\"))\n",
|
209 |
+
"id2label = {int(k): v for k, v in id2label_file.items()}\n",
|
210 |
+
"print(id2label)\n",
|
211 |
+
"label2id = {v: k for k, v in id2label_file.items()}\n",
|
212 |
+
"num_labels = len(id2label)\n",
|
213 |
+
"\n",
|
214 |
+
"model = SegformerForSemanticSegmentation.from_pretrained(\n",
|
215 |
+
" \"nvidia/mit-b0\", num_labels=num_labels, id2label=id2label, label2id=label2id,\n",
|
216 |
+
")"
|
217 |
+
]
|
218 |
+
},
|
219 |
+
{
|
220 |
+
"cell_type": "code",
|
221 |
+
"execution_count": 9,
|
222 |
+
"metadata": {},
|
223 |
+
"outputs": [],
|
224 |
+
"source": [
|
225 |
+
"metric = load_metric(\"mean_iou\")"
|
226 |
+
]
|
227 |
+
},
|
228 |
+
{
|
229 |
+
"cell_type": "code",
|
230 |
+
"execution_count": 10,
|
231 |
+
"metadata": {},
|
232 |
+
"outputs": [
|
233 |
+
{
|
234 |
+
"data": {
|
235 |
+
"application/vnd.jupyter.widget-view+json": {
|
236 |
+
"model_id": "091f8e1587f64625a4bbbf04f13a840e",
|
237 |
+
"version_major": 2,
|
238 |
+
"version_minor": 0
|
239 |
+
},
|
240 |
+
"text/plain": [
|
241 |
+
" 0%| | 0/25 [00:00<?, ?it/s]"
|
242 |
+
]
|
243 |
+
},
|
244 |
+
"metadata": {},
|
245 |
+
"output_type": "display_data"
|
246 |
+
},
|
247 |
+
{
|
248 |
+
"name": "stderr",
|
249 |
+
"output_type": "stream",
|
250 |
+
"text": [
|
251 |
+
"/home/chainyo/.cache/huggingface/modules/datasets_modules/metrics/mean_iou/d4add40cf977cdd73590b5873fa830f3f13adb678f6777a29fb07b7c81d14342/mean_iou.py:259: RuntimeWarning: invalid value encountered in true_divide\n",
|
252 |
+
" acc = total_area_intersect / total_area_label\n"
|
253 |
+
]
|
254 |
+
},
|
255 |
+
{
|
256 |
+
"name": "stdout",
|
257 |
+
"output_type": "stream",
|
258 |
+
"text": [
|
259 |
+
"Epoch 0/200 Batch 0/25 Loss 3.5735 Metrics {'mean_iou': 0.005477220659286406, 'mean_accuracy': 0.03572801337234697, 'overall_accuracy': 0.0286087999162653, 'per_category_iou': array([2.25093889e-02, 1.39043249e-03, 7.61652063e-03, 1.08658706e-02,\n",
|
260 |
+
" 1.00475237e-02, 0.00000000e+00, 2.03193907e-04, 1.38439962e-03,\n",
|
261 |
+
" 3.06388194e-05, 2.21495741e-03, 2.81951515e-05, 0.00000000e+00,\n",
|
262 |
+
" 0.00000000e+00, 1.13529929e-04, 0.00000000e+00, 0.00000000e+00,\n",
|
263 |
+
" 0.00000000e+00, 7.05787860e-03, 8.48350742e-03, 2.38476991e-03,\n",
|
264 |
+
" 1.76276224e-03, 2.17775404e-03, 0.00000000e+00, 1.04358386e-03,\n",
|
265 |
+
" 4.58714414e-04, 6.41413308e-05, 1.11431787e-04, 7.99247870e-02,\n",
|
266 |
+
" 8.28409255e-03, 2.07198485e-03, 6.12199013e-04, 6.53992687e-03,\n",
|
267 |
+
" 1.42611461e-02, 5.93919660e-05, 0.00000000e+00]), 'per_category_accuracy': array([2.49789663e-02, 1.40417891e-03, 5.70403118e-02, 1.19771803e-02,\n",
|
268 |
+
" 1.35858556e-02, nan, 2.43287079e-04, 1.42133234e-02,\n",
|
269 |
+
" 9.06344411e-03, 2.86141687e-03, 1.51057402e-03, nan,\n",
|
270 |
+
" nan, 1.54786781e-04, 0.00000000e+00, 0.00000000e+00,\n",
|
271 |
+
" nan, 7.59912501e-03, 1.08030661e-01, 7.46983779e-03,\n",
|
272 |
+
" 3.63612647e-03, 2.03376823e-01, nan, 1.34638923e-02,\n",
|
273 |
+
" 4.23971037e-03, 2.44969379e-03, 2.61780105e-02, 1.44916888e-01,\n",
|
274 |
+
" 1.09119869e-02, 2.12033947e-03, 5.77414410e-02, 8.69919527e-02,\n",
|
275 |
+
" 1.99881736e-01, 2.00708383e-02, nan])}\n"
|
276 |
+
]
|
277 |
+
},
|
278 |
+
{
|
279 |
+
"data": {
|
280 |
+
"application/vnd.jupyter.widget-view+json": {
|
281 |
+
"model_id": "8b14f2d3fe2044bf8940697dfce1f2d9",
|
282 |
+
"version_major": 2,
|
283 |
+
"version_minor": 0
|
284 |
+
},
|
285 |
+
"text/plain": [
|
286 |
+
" 0%| | 0/25 [00:00<?, ?it/s]"
|
287 |
+
]
|
288 |
+
},
|
289 |
+
"metadata": {},
|
290 |
+
"output_type": "display_data"
|
291 |
+
},
|
292 |
+
{
|
293 |
+
"name": "stdout",
|
294 |
+
"output_type": "stream",
|
295 |
+
"text": [
|
296 |
+
"Epoch 1/200 Batch 0/25 Loss 2.3349 Metrics {'mean_iou': 0.08249196196840773, 'mean_accuracy': 0.12242777101928329, 'overall_accuracy': 0.5340892205419953, 'per_category_iou': array([3.12479410e-01, 6.01364321e-01, 1.11562075e-02, 2.09116315e-02,\n",
|
297 |
+
" 1.95069293e-02, 0.00000000e+00, 1.09903909e-04, 1.13759900e-03,\n",
|
298 |
+
" 1.31593453e-03, 4.16974851e-01, 9.43479560e-04, 0.00000000e+00,\n",
|
299 |
+
" 0.00000000e+00, 2.71726824e-05, 7.02479634e-05, 1.71339744e-06,\n",
|
300 |
+
" 1.01240638e-04, 3.64420274e-01, 5.22553547e-03, 1.01179313e-03,\n",
|
301 |
+
" 5.40034112e-03, 6.14980455e-04, 0.00000000e+00, 1.37384839e-03,\n",
|
302 |
+
" 1.58339239e-03, 2.28333709e-05, 4.11671538e-05, 5.37431493e-01,\n",
|
303 |
+
" 6.31855647e-02, 4.88071958e-01, 5.19480641e-03, 4.67615947e-03,\n",
|
304 |
+
" 2.28171525e-02, 4.67269054e-05, 0.00000000e+00]), 'per_category_accuracy': array([4.06453404e-01, 7.78244778e-01, 2.33772733e-02, 2.33628638e-02,\n",
|
305 |
+
" 2.53283555e-02, nan, 1.11846810e-04, 3.42729695e-03,\n",
|
306 |
+
" 9.61392116e-03, 6.59758916e-01, 1.07169352e-02, 0.00000000e+00,\n",
|
307 |
+
" 0.00000000e+00, 1.71831147e-04, 1.20609014e-04, 1.32753642e-04,\n",
|
308 |
+
" 7.55138223e-03, 5.08304753e-01, 2.45514876e-02, 1.15429656e-03,\n",
|
309 |
+
" 6.00626887e-03, 1.89040602e-02, 0.00000000e+00, 3.03657284e-03,\n",
|
310 |
+
" 3.07788162e-03, 1.93985207e-04, 1.35743374e-03, 8.37114885e-01,\n",
|
311 |
+
" 7.86780250e-02, 5.18267366e-01, 1.14241257e-02, 1.51054789e-02,\n",
|
312 |
+
" 6.31875993e-02, 1.38005816e-03, nan])}\n"
|
313 |
+
]
|
314 |
+
},
|
315 |
+
{
|
316 |
+
"data": {
|
317 |
+
"application/vnd.jupyter.widget-view+json": {
|
318 |
+
"model_id": "1bcd35da7d4a43a0958776974bfa7918",
|
319 |
+
"version_major": 2,
|
320 |
+
"version_minor": 0
|
321 |
+
},
|
322 |
+
"text/plain": [
|
323 |
+
" 0%| | 0/25 [00:00<?, ?it/s]"
|
324 |
+
]
|
325 |
+
},
|
326 |
+
"metadata": {},
|
327 |
+
"output_type": "display_data"
|
328 |
+
},
|
329 |
+
{
|
330 |
+
"name": "stderr",
|
331 |
+
"output_type": "stream",
|
332 |
+
"text": [
|
333 |
+
"/home/chainyo/.cache/huggingface/modules/datasets_modules/metrics/mean_iou/d4add40cf977cdd73590b5873fa830f3f13adb678f6777a29fb07b7c81d14342/mean_iou.py:258: RuntimeWarning: invalid value encountered in true_divide\n",
|
334 |
+
" iou = total_area_intersect / total_area_union\n"
|
335 |
+
]
|
336 |
+
},
|
337 |
+
{
|
338 |
+
"name": "stdout",
|
339 |
+
"output_type": "stream",
|
340 |
+
"text": [
|
341 |
+
"Epoch 2/200 Batch 0/25 Loss 1.9441 Metrics {'mean_iou': 0.1115145961827685, 'mean_accuracy': 0.1571059701593332, 'overall_accuracy': 0.6731429879739427, 'per_category_iou': array([4.45592658e-01, 7.27905738e-01, 3.53301085e-05, 1.04372074e-02,\n",
|
342 |
+
" 7.95093029e-03, nan, 1.38316668e-06, 0.00000000e+00,\n",
|
343 |
+
" 0.00000000e+00, 4.80050947e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
344 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
345 |
+
" 0.00000000e+00, 4.95525330e-01, 0.00000000e+00, 1.11859236e-05,\n",
|
346 |
+
" 8.15231791e-07, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
347 |
+
" 2.30751037e-06, 0.00000000e+00, 0.00000000e+00, 6.24028521e-01,\n",
|
348 |
+
" 1.05303497e-01, 7.82230424e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
349 |
+
" 9.05398708e-04, 0.00000000e+00, nan]), 'per_category_accuracy': array([7.33079709e-01, 9.03717795e-01, 3.57307878e-05, 1.05455168e-02,\n",
|
350 |
+
" 8.21302511e-03, nan, 1.38319984e-06, 0.00000000e+00,\n",
|
351 |
+
" 0.00000000e+00, 8.08957376e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
352 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
353 |
+
" 0.00000000e+00, 8.13886465e-01, 0.00000000e+00, 1.11898068e-05,\n",
|
354 |
+
" 8.19780315e-07, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
355 |
+
" 2.30767678e-06, 0.00000000e+00, 0.00000000e+00, 9.49630788e-01,\n",
|
356 |
+
" 1.14320143e-01, 8.41172201e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
357 |
+
" 9.22565695e-04, 0.00000000e+00, nan])}\n"
|
358 |
+
]
|
359 |
+
},
|
360 |
+
{
|
361 |
+
"data": {
|
362 |
+
"application/vnd.jupyter.widget-view+json": {
|
363 |
+
"model_id": "f3ef333a6137481180b4ec59efd1673e",
|
364 |
+
"version_major": 2,
|
365 |
+
"version_minor": 0
|
366 |
+
},
|
367 |
+
"text/plain": [
|
368 |
+
" 0%| | 0/25 [00:00<?, ?it/s]"
|
369 |
+
]
|
370 |
+
},
|
371 |
+
"metadata": {},
|
372 |
+
"output_type": "display_data"
|
373 |
+
},
|
374 |
+
{
|
375 |
+
"name": "stdout",
|
376 |
+
"output_type": "stream",
|
377 |
+
"text": [
|
378 |
+
"Epoch 3/200 Batch 0/25 Loss 1.7086 Metrics {'mean_iou': 0.13337084618347653, 'mean_accuracy': 0.17885172041247846, 'overall_accuracy': 0.7166871222915292, 'per_category_iou': array([0.49122674, 0.77341676, 0. , 0.1407711 , 0.00553198,\n",
|
379 |
+
" nan, 0. , 0. , 0. , 0.55152752,\n",
|
380 |
+
" 0. , 0. , 0. , 0. , 0. ,\n",
|
381 |
+
" 0. , 0. , 0.51584332, 0. , 0. ,\n",
|
382 |
+
" 0. , 0. , 0. , 0. , 0. ,\n",
|
383 |
+
" 0. , 0. , 0.68907817, 0.42107346, 0.81276888,\n",
|
384 |
+
" 0. , 0. , 0. , 0. , nan]), 'per_category_accuracy': array([0.82412545, 0.91420412, 0. , 0.14242155, 0.00561311,\n",
|
385 |
+
" nan, 0. , 0. , 0. , 0.84759725,\n",
|
386 |
+
" 0. , 0. , 0. , 0. , 0. ,\n",
|
387 |
+
" 0. , 0. , 0.85361568, 0. , 0. ,\n",
|
388 |
+
" 0. , 0. , 0. , 0. , 0. ,\n",
|
389 |
+
" 0. , 0. , 0.93567573, 0.47723124, 0.90162263,\n",
|
390 |
+
" 0. , 0. , 0. , 0. , nan])}\n"
|
391 |
+
]
|
392 |
+
},
|
393 |
+
{
|
394 |
+
"data": {
|
395 |
+
"application/vnd.jupyter.widget-view+json": {
|
396 |
+
"model_id": "384340072f794e70bcbbc74dacc50c21",
|
397 |
+
"version_major": 2,
|
398 |
+
"version_minor": 0
|
399 |
+
},
|
400 |
+
"text/plain": [
|
401 |
+
" 0%| | 0/25 [00:00<?, ?it/s]"
|
402 |
+
]
|
403 |
+
},
|
404 |
+
"metadata": {},
|
405 |
+
"output_type": "display_data"
|
406 |
+
},
|
407 |
+
{
|
408 |
+
"name": "stdout",
|
409 |
+
"output_type": "stream",
|
410 |
+
"text": [
|
411 |
+
"Epoch 4/200 Batch 0/25 Loss 1.4155 Metrics {'mean_iou': 0.15564168770290954, 'mean_accuracy': 0.20019284726893685, 'overall_accuracy': 0.7574203051314757, 'per_category_iou': array([5.71505637e-01, 8.01399129e-01, 0.00000000e+00, 4.92403441e-01,\n",
|
412 |
+
" 5.41554099e-03, nan, 2.31577543e-07, 0.00000000e+00,\n",
|
413 |
+
" 0.00000000e+00, 5.86567051e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
414 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
415 |
+
" 0.00000000e+00, 5.51951880e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
416 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
417 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 7.24133440e-01,\n",
|
418 |
+
" 5.67064833e-01, 8.35733013e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
419 |
+
" 1.49737193e-06, 0.00000000e+00, nan]), 'per_category_accuracy': array([8.47391615e-01, 9.31621860e-01, 0.00000000e+00, 5.38104498e-01,\n",
|
420 |
+
" 5.46801709e-03, nan, 2.31577972e-07, 0.00000000e+00,\n",
|
421 |
+
" 0.00000000e+00, 8.70377721e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
422 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
423 |
+
" 0.00000000e+00, 8.79001612e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
424 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
425 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 9.26588036e-01,\n",
|
426 |
+
" 6.90579673e-01, 9.17229200e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
427 |
+
" 1.49737361e-06, 0.00000000e+00, nan])}\n"
|
428 |
+
]
|
429 |
+
},
|
430 |
+
{
|
431 |
+
"data": {
|
432 |
+
"application/vnd.jupyter.widget-view+json": {
|
433 |
+
"model_id": "428849c8d6ce414887a0661d3d4625da",
|
434 |
+
"version_major": 2,
|
435 |
+
"version_minor": 0
|
436 |
+
},
|
437 |
+
"text/plain": [
|
438 |
+
" 0%| | 0/25 [00:00<?, ?it/s]"
|
439 |
+
]
|
440 |
+
},
|
441 |
+
"metadata": {},
|
442 |
+
"output_type": "display_data"
|
443 |
+
},
|
444 |
+
{
|
445 |
+
"name": "stdout",
|
446 |
+
"output_type": "stream",
|
447 |
+
"text": [
|
448 |
+
"Epoch 5/200 Batch 0/25 Loss 1.2700 Metrics {'mean_iou': 0.16264865782481297, 'mean_accuracy': 0.20757696789528982, 'overall_accuracy': 0.7689239961674764, 'per_category_iou': array([0.58414589, 0.81498541, 0. , 0.57467831, 0.00689438,\n",
|
449 |
+
" nan, 0. , 0. , 0. , 0.61111453,\n",
|
450 |
+
" 0. , 0. , 0. , 0. , 0. ,\n",
|
451 |
+
" 0. , 0. , 0.55374014, 0. , 0. ,\n",
|
452 |
+
" 0. , 0. , 0. , 0. , 0. ,\n",
|
453 |
+
" 0. , 0. , 0.74374917, 0.6398338 , 0.83826408,\n",
|
454 |
+
" 0. , 0. , 0. , 0. , nan]), 'per_category_accuracy': array([0.86273737, 0.93244115, 0. , 0.66435561, 0.00696605,\n",
|
455 |
+
" nan, 0. , 0. , 0. , 0.87784737,\n",
|
456 |
+
" 0. , 0. , 0. , 0. , 0. ,\n",
|
457 |
+
" 0. , 0. , 0.87298997, 0. , 0. ,\n",
|
458 |
+
" 0. , 0. , 0. , 0. , 0. ,\n",
|
459 |
+
" 0. , 0. , 0.92712481, 0.78796366, 0.91761394,\n",
|
460 |
+
" 0. , 0. , 0. , 0. , nan])}\n"
|
461 |
+
]
|
462 |
+
},
|
463 |
+
{
|
464 |
+
"data": {
|
465 |
+
"application/vnd.jupyter.widget-view+json": {
|
466 |
+
"model_id": "1577eed162ec4162a69493365c950329",
|
467 |
+
"version_major": 2,
|
468 |
+
"version_minor": 0
|
469 |
+
},
|
470 |
+
"text/plain": [
|
471 |
+
" 0%| | 0/25 [00:00<?, ?it/s]"
|
472 |
+
]
|
473 |
+
},
|
474 |
+
"metadata": {},
|
475 |
+
"output_type": "display_data"
|
476 |
+
},
|
477 |
+
{
|
478 |
+
"name": "stdout",
|
479 |
+
"output_type": "stream",
|
480 |
+
"text": [
|
481 |
+
"Epoch 6/200 Batch 0/25 Loss 1.1401 Metrics {'mean_iou': 0.16748948766093116, 'mean_accuracy': 0.21181445128118748, 'overall_accuracy': 0.7795023598828317, 'per_category_iou': array([6.09680179e-01, 8.31918538e-01, 0.00000000e+00, 6.34236889e-01,\n",
|
482 |
+
" 1.24257235e-02, nan, 0.00000000e+00, 0.00000000e+00,\n",
|
483 |
+
" 0.00000000e+00, 6.35402026e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
484 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
485 |
+
" 0.00000000e+00, 5.64656796e-01, 0.00000000e+00, 3.65611521e-06,\n",
|
486 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
487 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 7.53253595e-01,\n",
|
488 |
+
" 6.35493134e-01, 8.50082556e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
489 |
+
" 0.00000000e+00, 0.00000000e+00, nan]), 'per_category_accuracy': array([8.81516256e-01, 9.41665624e-01, 0.00000000e+00, 7.39806944e-01,\n",
|
490 |
+
" 1.26589939e-02, nan, 0.00000000e+00, 0.00000000e+00,\n",
|
491 |
+
" 0.00000000e+00, 8.90136686e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
492 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
493 |
+
" 0.00000000e+00, 8.94297220e-01, 0.00000000e+00, 3.65611521e-06,\n",
|
494 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
495 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 9.26219754e-01,\n",
|
496 |
+
" 7.75692895e-01, 9.27878865e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
497 |
+
" 0.00000000e+00, 0.00000000e+00, nan])}\n"
|
498 |
+
]
|
499 |
+
},
|
500 |
+
{
|
501 |
+
"data": {
|
502 |
+
"application/vnd.jupyter.widget-view+json": {
|
503 |
+
"model_id": "169ae321c5ac4a4aab54a816c05035f2",
|
504 |
+
"version_major": 2,
|
505 |
+
"version_minor": 0
|
506 |
+
},
|
507 |
+
"text/plain": [
|
508 |
+
" 0%| | 0/25 [00:00<?, ?it/s]"
|
509 |
+
]
|
510 |
+
},
|
511 |
+
"metadata": {},
|
512 |
+
"output_type": "display_data"
|
513 |
+
},
|
514 |
+
{
|
515 |
+
"name": "stdout",
|
516 |
+
"output_type": "stream",
|
517 |
+
"text": [
|
518 |
+
"Epoch 7/200 Batch 0/25 Loss 1.0940 Metrics {'mean_iou': 0.17137856945919164, 'mean_accuracy': 0.21495584433690398, 'overall_accuracy': 0.7881396456781556, 'per_category_iou': array([6.32120417e-01, 8.39814384e-01, 0.00000000e+00, 6.53276797e-01,\n",
|
519 |
+
" 2.36145329e-02, nan, 0.00000000e+00, 0.00000000e+00,\n",
|
520 |
+
" 0.00000000e+00, 6.44513271e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
521 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
522 |
+
" 0.00000000e+00, 5.70917228e-01, 0.00000000e+00, 5.49670921e-05,\n",
|
523 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
524 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 7.63117039e-01,\n",
|
525 |
+
" 6.70860701e-01, 8.57197972e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
526 |
+
" 5.48363063e-06, 0.00000000e+00, nan]), 'per_category_accuracy': array([9.08405231e-01, 9.45289137e-01, 0.00000000e+00, 7.43501061e-01,\n",
|
527 |
+
" 2.43273947e-02, nan, 0.00000000e+00, 0.00000000e+00,\n",
|
528 |
+
" 0.00000000e+00, 8.96271845e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
529 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
530 |
+
" 0.00000000e+00, 8.93162955e-01, 0.00000000e+00, 5.49677426e-05,\n",
|
531 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
532 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 9.25997998e-01,\n",
|
533 |
+
" 8.26946192e-01, 9.29580599e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
534 |
+
" 5.48370681e-06, 0.00000000e+00, nan])}\n"
|
535 |
+
]
|
536 |
+
},
|
537 |
+
{
|
538 |
+
"data": {
|
539 |
+
"application/vnd.jupyter.widget-view+json": {
|
540 |
+
"model_id": "544f85de17a04a3ca74001e2fbbc6bc9",
|
541 |
+
"version_major": 2,
|
542 |
+
"version_minor": 0
|
543 |
+
},
|
544 |
+
"text/plain": [
|
545 |
+
" 0%| | 0/25 [00:00<?, ?it/s]"
|
546 |
+
]
|
547 |
+
},
|
548 |
+
"metadata": {},
|
549 |
+
"output_type": "display_data"
|
550 |
+
},
|
551 |
+
{
|
552 |
+
"name": "stdout",
|
553 |
+
"output_type": "stream",
|
554 |
+
"text": [
|
555 |
+
"Epoch 8/200 Batch 0/25 Loss 0.9463 Metrics {'mean_iou': 0.17478409568512296, 'mean_accuracy': 0.21773480038212398, 'overall_accuracy': 0.7927706422623841, 'per_category_iou': array([6.29223165e-01, 8.42392089e-01, 3.52355169e-04, 6.80396607e-01,\n",
|
556 |
+
" 5.21568283e-02, nan, 0.00000000e+00, 0.00000000e+00,\n",
|
557 |
+
" 0.00000000e+00, 6.55000862e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
558 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
559 |
+
" 0.00000000e+00, 5.75092138e-01, 0.00000000e+00, 7.12203670e-04,\n",
|
560 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
561 |
+
" 5.25684195e-05, 0.00000000e+00, 0.00000000e+00, 7.79574322e-01,\n",
|
562 |
+
" 6.87965551e-01, 8.64953847e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
563 |
+
" 2.62043422e-06, 0.00000000e+00, nan]), 'per_category_accuracy': array([9.02830276e-01, 9.46588697e-01, 3.52355621e-04, 7.80701400e-01,\n",
|
564 |
+
" 5.56946042e-02, nan, 0.00000000e+00, 0.00000000e+00,\n",
|
565 |
+
" 0.00000000e+00, 8.98895404e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
566 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
567 |
+
" 0.00000000e+00, 8.98825852e-01, 0.00000000e+00, 7.12358991e-04,\n",
|
568 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
569 |
+
" 5.25687464e-05, 0.00000000e+00, 0.00000000e+00, 9.33269035e-01,\n",
|
570 |
+
" 8.30703600e-01, 9.36619641e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
571 |
+
" 2.62043913e-06, 0.00000000e+00, nan])}\n"
|
572 |
+
]
|
573 |
+
},
|
574 |
+
{
|
575 |
+
"data": {
|
576 |
+
"application/vnd.jupyter.widget-view+json": {
|
577 |
+
"model_id": "a88b7f1699074168b624caaad5466071",
|
578 |
+
"version_major": 2,
|
579 |
+
"version_minor": 0
|
580 |
+
},
|
581 |
+
"text/plain": [
|
582 |
+
" 0%| | 0/25 [00:00<?, ?it/s]"
|
583 |
+
]
|
584 |
+
},
|
585 |
+
"metadata": {},
|
586 |
+
"output_type": "display_data"
|
587 |
+
},
|
588 |
+
{
|
589 |
+
"name": "stdout",
|
590 |
+
"output_type": "stream",
|
591 |
+
"text": [
|
592 |
+
"Epoch 9/200 Batch 0/25 Loss 1.1287 Metrics {'mean_iou': 0.17933414157617142, 'mean_accuracy': 0.22211677037195932, 'overall_accuracy': 0.7957460527936768, 'per_category_iou': array([6.42468748e-01, 8.51322031e-01, 4.38690416e-02, 6.83716408e-01,\n",
|
593 |
+
" 1.05816211e-01, nan, 3.89710724e-04, 0.00000000e+00,\n",
|
594 |
+
" 0.00000000e+00, 6.68283305e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
595 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
596 |
+
" 0.00000000e+00, 5.86580529e-01, 0.00000000e+00, 1.00559462e-03,\n",
|
597 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
598 |
+
" 1.15812990e-04, 0.00000000e+00, 0.00000000e+00, 7.77224737e-01,\n",
|
599 |
+
" 6.87868711e-01, 8.69336423e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
600 |
+
" 2.94092022e-05, 0.00000000e+00, nan]), 'per_category_accuracy': array([9.05123309e-01, 9.50701912e-01, 4.38913906e-02, 7.96283141e-01,\n",
|
601 |
+
" 1.19598233e-01, nan, 3.89758329e-04, 0.00000000e+00,\n",
|
602 |
+
" 0.00000000e+00, 8.97971821e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
603 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
604 |
+
" 0.00000000e+00, 9.09473869e-01, 0.00000000e+00, 1.00721614e-03,\n",
|
605 |
+
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
|
606 |
+
" 1.15831636e-04, 0.00000000e+00, 0.00000000e+00, 9.30939519e-01,\n",
|
607 |
+
" 8.38613121e-01, 9.35714875e-01, 0.00000000e+00, 0.00000000e+00,\n",
|
608 |
+
" 2.94260342e-05, 0.00000000e+00, nan])}\n"
|
609 |
+
]
|
610 |
+
},
|
611 |
+
{
|
612 |
+
"data": {
|
613 |
+
"application/vnd.jupyter.widget-view+json": {
|
614 |
+
"model_id": "569b4a3ebdd24be7811709fc8d588e30",
|
615 |
+
"version_major": 2,
|
616 |
+
"version_minor": 0
|
617 |
+
},
|
618 |
+
"text/plain": [
|
619 |
+
" 0%| | 0/25 [00:00<?, ?it/s]"
|
620 |
+
]
|
621 |
+
},
|
622 |
+
"metadata": {},
|
623 |
+
"output_type": "display_data"
|
624 |
+
},
|
625 |
+
{
|
626 |
+
"ename": "KeyboardInterrupt",
|
627 |
+
"evalue": "",
|
628 |
+
"output_type": "error",
|
629 |
+
"traceback": [
|
630 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
631 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
632 |
+
"\u001b[1;32m/home/chainyo/code/segformer-sidewalk/finetuning.ipynb Cell 10'\u001b[0m in \u001b[0;36m<cell line: 7>\u001b[0;34m()\u001b[0m\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000022?line=24'>25</a>\u001b[0m metric\u001b[39m.\u001b[39madd_batch(predictions\u001b[39m=\u001b[39mpredicted\u001b[39m.\u001b[39mdetach()\u001b[39m.\u001b[39mcpu()\u001b[39m.\u001b[39mnumpy(), references\u001b[39m=\u001b[39mlabels\u001b[39m.\u001b[39mdetach()\u001b[39m.\u001b[39mcpu()\u001b[39m.\u001b[39mnumpy())\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000022?line=26'>27</a>\u001b[0m \u001b[39mif\u001b[39;00m index \u001b[39m%\u001b[39m \u001b[39m100\u001b[39m \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[0;32m---> <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000022?line=27'>28</a>\u001b[0m metrics \u001b[39m=\u001b[39m metric\u001b[39m.\u001b[39;49mcompute(num_labels\u001b[39m=\u001b[39;49mnum_labels, ignore_index\u001b[39m=\u001b[39;49m\u001b[39m255\u001b[39;49m, reduce_labels\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m)\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000022?line=28'>29</a>\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mEpoch \u001b[39m\u001b[39m{\u001b[39;00mepoch\u001b[39m}\u001b[39;00m\u001b[39m/\u001b[39m\u001b[39m{\u001b[39;00mEPOCHS\u001b[39m}\u001b[39;00m\u001b[39m Batch \u001b[39m\u001b[39m{\u001b[39;00mindex\u001b[39m}\u001b[39;00m\u001b[39m/\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mlen\u001b[39m(train_dataloader)\u001b[39m}\u001b[39;00m\u001b[39m Loss \u001b[39m\u001b[39m{\u001b[39;00mloss\u001b[39m.\u001b[39mitem()\u001b[39m:\u001b[39;00m\u001b[39m.4f\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m Metrics \u001b[39m\u001b[39m{\u001b[39;00mmetrics\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n",
|
633 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py:428\u001b[0m, in \u001b[0;36mMetric.compute\u001b[0;34m(self, predictions, references, **kwargs)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=424'>425</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mprocess_id \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=425'>426</a>\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdata\u001b[39m.\u001b[39mset_format(\u001b[39mtype\u001b[39m\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39minfo\u001b[39m.\u001b[39mformat)\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=427'>428</a>\u001b[0m inputs \u001b[39m=\u001b[39m {input_name: \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdata[input_name] \u001b[39mfor\u001b[39;00m input_name \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfeatures}\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=428'>429</a>\u001b[0m \u001b[39mwith\u001b[39;00m temp_seed(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mseed):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=429'>430</a>\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compute(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39minputs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mcompute_kwargs)\n",
|
634 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py:428\u001b[0m, in \u001b[0;36m<dictcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=424'>425</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mprocess_id \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=425'>426</a>\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdata\u001b[39m.\u001b[39mset_format(\u001b[39mtype\u001b[39m\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39minfo\u001b[39m.\u001b[39mformat)\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=427'>428</a>\u001b[0m inputs \u001b[39m=\u001b[39m {input_name: \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdata[input_name] \u001b[39mfor\u001b[39;00m input_name \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfeatures}\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=428'>429</a>\u001b[0m \u001b[39mwith\u001b[39;00m temp_seed(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mseed):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=429'>430</a>\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compute(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39minputs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mcompute_kwargs)\n",
|
635 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py:2124\u001b[0m, in \u001b[0;36mDataset.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2121'>2122</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__getitem__\u001b[39m(\u001b[39mself\u001b[39m, key): \u001b[39m# noqa: F811\u001b[39;00m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2122'>2123</a>\u001b[0m \u001b[39m\"\"\"Can be used to index columns (by string names) or rows (by integer index or iterable of indices or bools).\"\"\"\u001b[39;00m\n\u001b[0;32m-> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2123'>2124</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_getitem(\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2124'>2125</a>\u001b[0m key,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2125'>2126</a>\u001b[0m )\n",
|
636 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py:2109\u001b[0m, in \u001b[0;36mDataset._getitem\u001b[0;34m(self, key, decoded, **kwargs)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2106'>2107</a>\u001b[0m formatter \u001b[39m=\u001b[39m get_formatter(format_type, features\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfeatures, decoded\u001b[39m=\u001b[39mdecoded, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mformat_kwargs)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2107'>2108</a>\u001b[0m pa_subtable \u001b[39m=\u001b[39m query_table(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_data, key, indices\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_indices \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_indices \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m)\n\u001b[0;32m-> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2108'>2109</a>\u001b[0m formatted_output \u001b[39m=\u001b[39m format_table(\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2109'>2110</a>\u001b[0m pa_subtable, key, formatter\u001b[39m=\u001b[39;49mformatter, format_columns\u001b[39m=\u001b[39;49mformat_columns, output_all_columns\u001b[39m=\u001b[39;49moutput_all_columns\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2110'>2111</a>\u001b[0m )\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2111'>2112</a>\u001b[0m \u001b[39mreturn\u001b[39;00m formatted_output\n",
|
637 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py:532\u001b[0m, in \u001b[0;36mformat_table\u001b[0;34m(table, key, formatter, format_columns, output_all_columns)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=529'>530</a>\u001b[0m python_formatter \u001b[39m=\u001b[39m PythonFormatter(features\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=530'>531</a>\u001b[0m \u001b[39mif\u001b[39;00m format_columns \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=531'>532</a>\u001b[0m \u001b[39mreturn\u001b[39;00m formatter(pa_table, query_type\u001b[39m=\u001b[39;49mquery_type)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=532'>533</a>\u001b[0m \u001b[39melif\u001b[39;00m query_type \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mcolumn\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=533'>534</a>\u001b[0m \u001b[39mif\u001b[39;00m key \u001b[39min\u001b[39;00m format_columns:\n",
|
638 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py:283\u001b[0m, in \u001b[0;36mFormatter.__call__\u001b[0;34m(self, pa_table, query_type)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=280'>281</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mformat_row(pa_table)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=281'>282</a>\u001b[0m \u001b[39melif\u001b[39;00m query_type \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mcolumn\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=282'>283</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mformat_column(pa_table)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=283'>284</a>\u001b[0m \u001b[39melif\u001b[39;00m query_type \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mbatch\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=284'>285</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mformat_batch(pa_table)\n",
|
639 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py:316\u001b[0m, in \u001b[0;36mPythonFormatter.format_column\u001b[0;34m(self, pa_table)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=314'>315</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mformat_column\u001b[39m(\u001b[39mself\u001b[39m, pa_table: pa\u001b[39m.\u001b[39mTable) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mlist\u001b[39m:\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=315'>316</a>\u001b[0m column \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mpython_arrow_extractor()\u001b[39m.\u001b[39;49mextract_column(pa_table)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=316'>317</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdecoded:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=317'>318</a>\u001b[0m column \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpython_features_decoder\u001b[39m.\u001b[39mdecode_column(column, pa_table\u001b[39m.\u001b[39mcolumn_names[\u001b[39m0\u001b[39m])\n",
|
640 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py:143\u001b[0m, in \u001b[0;36mPythonArrowExtractor.extract_column\u001b[0;34m(self, pa_table)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=141'>142</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mextract_column\u001b[39m(\u001b[39mself\u001b[39m, pa_table: pa\u001b[39m.\u001b[39mTable) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mlist\u001b[39m:\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=142'>143</a>\u001b[0m \u001b[39mreturn\u001b[39;00m pa_table\u001b[39m.\u001b[39;49mcolumn(\u001b[39m0\u001b[39;49m)\u001b[39m.\u001b[39;49mto_pylist()\n",
|
641 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
642 |
+
]
|
643 |
+
}
|
644 |
+
],
|
645 |
+
"source": [
|
646 |
+
"optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006)\n",
|
647 |
+
"\n",
|
648 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
649 |
+
"model.to(device)\n",
|
650 |
+
"\n",
|
651 |
+
"model.train()\n",
|
652 |
+
"for epoch in range(EPOCHS):\n",
|
653 |
+
" for index, batch in enumerate(tqdm(train_dataloader)):\n",
|
654 |
+
" pixel_values = batch[\"pixel_values\"].to(device)\n",
|
655 |
+
" labels = batch[\"labels\"].to(device)\n",
|
656 |
+
"\n",
|
657 |
+
" optimizer.zero_grad()\n",
|
658 |
+
"\n",
|
659 |
+
" outputs = model(pixel_values=pixel_values, labels=labels)\n",
|
660 |
+
" loss, logits = outputs.loss, outputs.logits\n",
|
661 |
+
"\n",
|
662 |
+
" loss.backward()\n",
|
663 |
+
" optimizer.step()\n",
|
664 |
+
"\n",
|
665 |
+
" with torch.no_grad():\n",
|
666 |
+
" upsampled_logits = nn.functional.interpolate(\n",
|
667 |
+
" logits, size=labels.shape[-2:], mode=\"bilinear\", align_corners=False\n",
|
668 |
+
" )\n",
|
669 |
+
" predicted = upsampled_logits.argmax(dim=1)\n",
|
670 |
+
" metric.add_batch(predictions=predicted.detach().cpu().numpy(), references=labels.detach().cpu().numpy())\n",
|
671 |
+
"\n",
|
672 |
+
" if index % 100 == 0:\n",
|
673 |
+
" metrics = metric.compute(num_labels=num_labels, ignore_index=255, reduce_labels=False)\n",
|
674 |
+
" print(f\"Epoch {epoch}/{EPOCHS} Batch {index}/{len(train_dataloader)} Loss {loss.item():.4f} Metrics {metrics}\")"
|
675 |
+
]
|
676 |
+
},
|
677 |
+
{
|
678 |
+
"cell_type": "code",
|
679 |
+
"execution_count": 5,
|
680 |
+
"metadata": {},
|
681 |
+
"outputs": [
|
682 |
+
{
|
683 |
+
"name": "stderr",
|
684 |
+
"output_type": "stream",
|
685 |
+
"text": [
|
686 |
+
"Using custom data configuration segments--sidewalk-semantic-2-f89d0845be9cadc9\n",
|
687 |
+
"Reusing dataset parquet (/home/chainyo/.cache/huggingface/datasets/segments___parquet/segments--sidewalk-semantic-2-f89d0845be9cadc9/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)\n"
|
688 |
+
]
|
689 |
+
},
|
690 |
+
{
|
691 |
+
"name": "stdout",
|
692 |
+
"output_type": "stream",
|
693 |
+
"text": [
|
694 |
+
"1000\n"
|
695 |
+
]
|
696 |
+
}
|
697 |
+
],
|
698 |
+
"source": [
|
699 |
+
"import numpy as np\n",
|
700 |
+
"\n",
|
701 |
+
"tokenizer = SegformerFeatureExtractor(reduce_labels=True)\n",
|
702 |
+
"\n",
|
703 |
+
"dataset = load_dataset(HUB_DIR, split=\"train\")\n",
|
704 |
+
"length = len(dataset)\n",
|
705 |
+
"print(length)\n",
|
706 |
+
"\n",
|
707 |
+
"encoded_dataset = tokenizer(\n",
|
708 |
+
" images=dataset[\"pixel_values\"], segmentation_maps=dataset[\"label\"], return_tensors=\"pt\"\n",
|
709 |
+
")"
|
710 |
+
]
|
711 |
+
},
|
712 |
+
{
|
713 |
+
"cell_type": "code",
|
714 |
+
"execution_count": 9,
|
715 |
+
"metadata": {},
|
716 |
+
"outputs": [],
|
717 |
+
"source": [
|
718 |
+
"pixel_values = encoded_dataset[\"pixel_values\"]\n",
|
719 |
+
"labels = encoded_dataset[\"labels\"]"
|
720 |
+
]
|
721 |
+
},
|
722 |
+
{
|
723 |
+
"cell_type": "code",
|
724 |
+
"execution_count": 10,
|
725 |
+
"metadata": {},
|
726 |
+
"outputs": [
|
727 |
+
{
|
728 |
+
"data": {
|
729 |
+
"text/plain": [
|
730 |
+
"torch.Tensor"
|
731 |
+
]
|
732 |
+
},
|
733 |
+
"execution_count": 10,
|
734 |
+
"metadata": {},
|
735 |
+
"output_type": "execute_result"
|
736 |
+
}
|
737 |
+
],
|
738 |
+
"source": [
|
739 |
+
"type(pixel_values)"
|
740 |
+
]
|
741 |
+
},
|
742 |
+
{
|
743 |
+
"cell_type": "code",
|
744 |
+
"execution_count": 25,
|
745 |
+
"metadata": {},
|
746 |
+
"outputs": [],
|
747 |
+
"source": [
|
748 |
+
"from torch.utils.data import DataLoader, Dataset, random_split, Subset\n",
|
749 |
+
"\n",
|
750 |
+
"class SegmentationDataset(Dataset):\n",
|
751 |
+
" def __init__(self, pixel_values: torch.Tensor, labels: torch.Tensor):\n",
|
752 |
+
" self.pixel_values = pixel_values\n",
|
753 |
+
" self.labels = labels\n",
|
754 |
+
" assert pixel_values.shape[0] == labels.shape[0]\n",
|
755 |
+
" self.length = pixel_values.shape[0]\n",
|
756 |
+
" print(f\"Created dataset with {self.length} samples\")\n",
|
757 |
+
" \n",
|
758 |
+
"\n",
|
759 |
+
" def __len__(self):\n",
|
760 |
+
" return self.length\n",
|
761 |
+
"\n",
|
762 |
+
"\n",
|
763 |
+
" def __getitem__(self, index):\n",
|
764 |
+
" image = self.pixel_values[index]\n",
|
765 |
+
" label = self.labels[index]\n",
|
766 |
+
"\n",
|
767 |
+
" encoded_inputs = BatchFeature({\"pixel_values\": image, \"labels\": label})\n",
|
768 |
+
"\n",
|
769 |
+
" return encoded_inputs"
|
770 |
+
]
|
771 |
+
},
|
772 |
+
{
|
773 |
+
"cell_type": "code",
|
774 |
+
"execution_count": 26,
|
775 |
+
"metadata": {},
|
776 |
+
"outputs": [
|
777 |
+
{
|
778 |
+
"name": "stdout",
|
779 |
+
"output_type": "stream",
|
780 |
+
"text": [
|
781 |
+
"Created dataset with 1000 samples\n"
|
782 |
+
]
|
783 |
+
}
|
784 |
+
],
|
785 |
+
"source": [
|
786 |
+
"segmentation_dataset = SegmentationDataset(pixel_values, labels)"
|
787 |
+
]
|
788 |
+
},
|
789 |
+
{
|
790 |
+
"cell_type": "code",
|
791 |
+
"execution_count": 27,
|
792 |
+
"metadata": {},
|
793 |
+
"outputs": [],
|
794 |
+
"source": [
|
795 |
+
"test = segmentation_dataset[0]"
|
796 |
+
]
|
797 |
+
},
|
798 |
+
{
|
799 |
+
"cell_type": "code",
|
800 |
+
"execution_count": 28,
|
801 |
+
"metadata": {},
|
802 |
+
"outputs": [
|
803 |
+
{
|
804 |
+
"data": {
|
805 |
+
"text/plain": [
|
806 |
+
"{'pixel_values': tensor([[[ 0.0912, -0.1828, -0.1143, ..., -0.5253, -0.5424, -0.6623],\n",
|
807 |
+
" [-0.0116, -0.2342, -0.1486, ..., -0.5424, -0.6109, -0.7137],\n",
|
808 |
+
" [ 0.0398, -0.1828, -0.1314, ..., -0.5082, -0.6281, -0.7650],\n",
|
809 |
+
" ...,\n",
|
810 |
+
" [ 1.1529, 1.3927, 1.0331, ..., 0.4166, 0.3481, 0.3309],\n",
|
811 |
+
" [ 0.9474, 1.1358, 1.3070, ..., 0.5022, 0.3652, 0.3994],\n",
|
812 |
+
" [ 0.6049, 1.3413, 1.1358, ..., 1.2728, 0.6563, 0.8104]],\n",
|
813 |
+
"\n",
|
814 |
+
" [[ 0.3102, -0.2150, -0.3200, ..., -0.3901, -0.4426, -0.5651],\n",
|
815 |
+
" [ 0.2227, -0.2850, -0.3725, ..., -0.4076, -0.5126, -0.6176],\n",
|
816 |
+
" [ 0.2577, -0.2325, -0.3550, ..., -0.3725, -0.5301, -0.6702],\n",
|
817 |
+
" ...,\n",
|
818 |
+
" [ 1.1506, 1.3957, 1.0280, ..., 0.4678, 0.3978, 0.3803],\n",
|
819 |
+
" [ 0.9405, 1.1331, 1.3081, ..., 0.5553, 0.4153, 0.4503],\n",
|
820 |
+
" [ 0.5903, 1.3431, 1.1331, ..., 1.3431, 0.7129, 0.8704]],\n",
|
821 |
+
"\n",
|
822 |
+
" [[ 0.4788, -0.0267, -0.1312, ..., 0.0431, -0.0441, -0.2010],\n",
|
823 |
+
" [ 0.3916, -0.0790, -0.1835, ..., 0.0256, -0.1138, -0.2532],\n",
|
824 |
+
" [ 0.4265, -0.0267, -0.1661, ..., 0.0605, -0.1312, -0.3055],\n",
|
825 |
+
" ...,\n",
|
826 |
+
" [ 1.2805, 1.5245, 1.1585, ..., 0.6356, 0.5659, 0.5485],\n",
|
827 |
+
" [ 1.0714, 1.2631, 1.4374, ..., 0.7228, 0.5834, 0.6182],\n",
|
828 |
+
" [ 0.7228, 1.4722, 1.2631, ..., 1.5071, 0.8797, 1.0365]]]), 'labels': tensor([[17, 17, 17, ..., 17, 17, 17],\n",
|
829 |
+
" [17, 17, 17, ..., 17, 17, 17],\n",
|
830 |
+
" [17, 17, 17, ..., 17, 17, 17],\n",
|
831 |
+
" ...,\n",
|
832 |
+
" [ 1, 1, 1, ..., 1, 1, 1],\n",
|
833 |
+
" [ 1, 1, 1, ..., 1, 1, 1],\n",
|
834 |
+
" [ 1, 1, 1, ..., 1, 1, 1]])}"
|
835 |
+
]
|
836 |
+
},
|
837 |
+
"execution_count": 28,
|
838 |
+
"metadata": {},
|
839 |
+
"output_type": "execute_result"
|
840 |
+
}
|
841 |
+
],
|
842 |
+
"source": [
|
843 |
+
"test"
|
844 |
+
]
|
845 |
+
},
|
846 |
+
{
|
847 |
+
"cell_type": "code",
|
848 |
+
"execution_count": 20,
|
849 |
+
"metadata": {},
|
850 |
+
"outputs": [
|
851 |
+
{
|
852 |
+
"data": {
|
853 |
+
"text/plain": [
|
854 |
+
"tensor([ 0, 1, 2, 4, 6, 9, 17, 18, 19, 23, 24, 25, 27, 31, 32])"
|
855 |
+
]
|
856 |
+
},
|
857 |
+
"execution_count": 20,
|
858 |
+
"metadata": {},
|
859 |
+
"output_type": "execute_result"
|
860 |
+
}
|
861 |
+
],
|
862 |
+
"source": [
|
863 |
+
"segmentation_dataset[0][1].squeeze().unique()"
|
864 |
+
]
|
865 |
+
},
|
866 |
+
{
|
867 |
+
"cell_type": "code",
|
868 |
+
"execution_count": 22,
|
869 |
+
"metadata": {},
|
870 |
+
"outputs": [],
|
871 |
+
"source": [
|
872 |
+
"indices = np.arange(length)\n",
|
873 |
+
"train_indices, val_indices = random_split(indices, [int(length * 0.8), int(length * 0.2)])\n",
|
874 |
+
"\n",
|
875 |
+
"train_dataset = SegmentationDataset(encoded_dataset, train_indices)\n",
|
876 |
+
"val_dataset = SegmentationDataset(encoded_dataset, val_indices)"
|
877 |
+
]
|
878 |
+
},
|
879 |
+
{
|
880 |
+
"cell_type": "code",
|
881 |
+
"execution_count": 26,
|
882 |
+
"metadata": {},
|
883 |
+
"outputs": [
|
884 |
+
{
|
885 |
+
"ename": "TypeError",
|
886 |
+
"evalue": "list indices must be integers or slices, not str",
|
887 |
+
"output_type": "error",
|
888 |
+
"traceback": [
|
889 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
890 |
+
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
|
891 |
+
"\u001b[1;32m/home/chainyo/code/segformer-sidewalk/finetuning.ipynb Cell 14'\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000018?line=0'>1</a>\u001b[0m train_dataset[\u001b[39m0\u001b[39;49m]\n",
|
892 |
+
"\u001b[1;32m/home/chainyo/code/segformer-sidewalk/finetuning.ipynb Cell 12'\u001b[0m in \u001b[0;36mSegmentationDataset.__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=11'>12</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__getitem__\u001b[39m(\u001b[39mself\u001b[39m, index):\n\u001b[0;32m---> <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=12'>13</a>\u001b[0m image \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset[\u001b[39m\"\u001b[39;49m\u001b[39mpixel_values\u001b[39;49m\u001b[39m\"\u001b[39;49m][index]\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=13'>14</a>\u001b[0m label \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[\u001b[39m\"\u001b[39m\u001b[39mlabel\u001b[39m\u001b[39m\"\u001b[39m][index]\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=15'>16</a>\u001b[0m \u001b[39mreturn\u001b[39;00m image, label\n",
|
893 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py:471\u001b[0m, in \u001b[0;36mSubset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=468'>469</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(idx, \u001b[39mlist\u001b[39m):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=469'>470</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mindices[i] \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m idx]]\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=470'>471</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mindices[idx]]\n",
|
894 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py:471\u001b[0m, in \u001b[0;36mSubset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=468'>469</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(idx, \u001b[39mlist\u001b[39m):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=469'>470</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mindices[i] \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m idx]]\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=470'>471</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mindices[idx]]\n",
|
895 |
+
"\u001b[0;31mTypeError\u001b[0m: list indices must be integers or slices, not str"
|
896 |
+
]
|
897 |
+
}
|
898 |
+
],
|
899 |
+
"source": [
|
900 |
+
"train_dataset[0]"
|
901 |
+
]
|
902 |
+
},
|
903 |
+
{
|
904 |
+
"cell_type": "code",
|
905 |
+
"execution_count": 23,
|
906 |
+
"metadata": {},
|
907 |
+
"outputs": [],
|
908 |
+
"source": [
|
909 |
+
"train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)\n",
|
910 |
+
"valid_dataloader = DataLoader(val_dataset, batch_size=2)"
|
911 |
+
]
|
912 |
+
},
|
913 |
+
{
|
914 |
+
"cell_type": "code",
|
915 |
+
"execution_count": 24,
|
916 |
+
"metadata": {},
|
917 |
+
"outputs": [
|
918 |
+
{
|
919 |
+
"ename": "TypeError",
|
920 |
+
"evalue": "list indices must be integers or slices, not str",
|
921 |
+
"output_type": "error",
|
922 |
+
"traceback": [
|
923 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
924 |
+
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
|
925 |
+
"\u001b[1;32m/home/chainyo/code/segformer-sidewalk/finetuning.ipynb Cell 15'\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=0'>1</a>\u001b[0m batch \u001b[39m=\u001b[39m \u001b[39mnext\u001b[39;49m(\u001b[39miter\u001b[39;49m(train_dataloader))\n",
|
926 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py:530\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=527'>528</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sampler_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=528'>529</a>\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reset()\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=529'>530</a>\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_next_data()\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=530'>531</a>\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=531'>532</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=532'>533</a>\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=533'>534</a>\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called:\n",
|
927 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py:570\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=567'>568</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_next_data\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=568'>569</a>\u001b[0m index \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_next_index() \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=569'>570</a>\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_dataset_fetcher\u001b[39m.\u001b[39;49mfetch(index) \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=570'>571</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=571'>572</a>\u001b[0m data \u001b[39m=\u001b[39m _utils\u001b[39m.\u001b[39mpin_memory\u001b[39m.\u001b[39mpin_memory(data)\n",
|
928 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py:49\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=46'>47</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mfetch\u001b[39m(\u001b[39mself\u001b[39m, possibly_batched_index):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=47'>48</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mauto_collation:\n\u001b[0;32m---> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=48'>49</a>\u001b[0m data \u001b[39m=\u001b[39m [\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[idx] \u001b[39mfor\u001b[39;00m idx \u001b[39min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=49'>50</a>\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=50'>51</a>\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[possibly_batched_index]\n",
|
929 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py:49\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=46'>47</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mfetch\u001b[39m(\u001b[39mself\u001b[39m, possibly_batched_index):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=47'>48</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mauto_collation:\n\u001b[0;32m---> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=48'>49</a>\u001b[0m data \u001b[39m=\u001b[39m [\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset[idx] \u001b[39mfor\u001b[39;00m idx \u001b[39min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=49'>50</a>\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=50'>51</a>\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[possibly_batched_index]\n",
|
930 |
+
"\u001b[1;32m/home/chainyo/code/segformer-sidewalk/finetuning.ipynb Cell 12'\u001b[0m in \u001b[0;36mSegmentationDataset.__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=11'>12</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__getitem__\u001b[39m(\u001b[39mself\u001b[39m, index):\n\u001b[0;32m---> <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=12'>13</a>\u001b[0m image \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset[\u001b[39m\"\u001b[39;49m\u001b[39mpixel_values\u001b[39;49m\u001b[39m\"\u001b[39;49m][index]\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=13'>14</a>\u001b[0m label \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[\u001b[39m\"\u001b[39m\u001b[39mlabel\u001b[39m\u001b[39m\"\u001b[39m][index]\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=15'>16</a>\u001b[0m \u001b[39mreturn\u001b[39;00m image, label\n",
|
931 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py:471\u001b[0m, in \u001b[0;36mSubset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=468'>469</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(idx, \u001b[39mlist\u001b[39m):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=469'>470</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mindices[i] \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m idx]]\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=470'>471</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mindices[idx]]\n",
|
932 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py:471\u001b[0m, in \u001b[0;36mSubset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=468'>469</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(idx, \u001b[39mlist\u001b[39m):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=469'>470</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mindices[i] \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m idx]]\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=470'>471</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mindices[idx]]\n",
|
933 |
+
"\u001b[0;31mTypeError\u001b[0m: list indices must be integers or slices, not str"
|
934 |
+
]
|
935 |
+
}
|
936 |
+
],
|
937 |
+
"source": [
|
938 |
+
"batch = next(iter(train_dataloader))"
|
939 |
+
]
|
940 |
+
},
|
941 |
+
{
|
942 |
+
"cell_type": "code",
|
943 |
+
"execution_count": null,
|
944 |
+
"metadata": {},
|
945 |
+
"outputs": [],
|
946 |
+
"source": []
|
947 |
+
}
|
948 |
+
],
|
949 |
+
"metadata": {
|
950 |
+
"interpreter": {
|
951 |
+
"hash": "19a4294f02aad727716e8ed0f765e04171ea0ecb4c129d0b0eebf53be4c3a095"
|
952 |
+
},
|
953 |
+
"kernelspec": {
|
954 |
+
"display_name": "Python 3.8.13 ('segformer')",
|
955 |
+
"language": "python",
|
956 |
+
"name": "python3"
|
957 |
+
},
|
958 |
+
"language_info": {
|
959 |
+
"codemirror_mode": {
|
960 |
+
"name": "ipython",
|
961 |
+
"version": 3
|
962 |
+
},
|
963 |
+
"file_extension": ".py",
|
964 |
+
"mimetype": "text/x-python",
|
965 |
+
"name": "python",
|
966 |
+
"nbconvert_exporter": "python",
|
967 |
+
"pygments_lexer": "ipython3",
|
968 |
+
"version": "3.8.13"
|
969 |
+
},
|
970 |
+
"orig_nbformat": 4
|
971 |
+
},
|
972 |
+
"nbformat": 4,
|
973 |
+
"nbformat_minor": 2
|
974 |
+
}
|