{ "cells": [
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Fine-tuning SegFormer on `segments/sidewalk-semantic`\n", "\n", "Loads the sidewalk dataset from the Hugging Face Hub, encodes it with `SegformerFeatureExtractor`, and fine-tunes the `nvidia/mit-b0` checkpoint with `SegformerForSemanticSegmentation`." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import json\n", "import pytorch_lightning as pl\n", "import torch\n", "import torchmetrics\n", "\n", "from datasets import load_dataset, load_metric\n", "\n", "from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation\n", "\n", "from torch import nn\n", "from torch.utils.data import DataLoader, Dataset, random_split\n", "\n", "from tqdm.notebook import tqdm" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [
"class SemanticSegmentationDataset(Dataset):\n",
"    \"\"\"Image segmentation dataset that encodes each sample on access.\"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        dataset: torch.utils.data.dataset.Subset,\n",
"        feature_extractor=None,\n",
"    ):\n",
"        \"\"\"\n",
"        Initialize the dataset with the given samples and feature extractor.\n",
"\n",
"        Parameters\n",
"        ----------\n",
"        dataset : torch.utils.data.dataset.Subset\n",
"            The samples to wrap. Each item must provide an image under \"pixel_values\"\n",
"            and its segmentation map under \"label\".\n",
"        feature_extractor : SegformerFeatureExtractor, optional\n",
"            The feature extractor to use. Defaults to a new\n",
"            SegformerFeatureExtractor(reduce_labels=True) created for this instance.\n",
"        \"\"\"\n",
"        self.dataset = dataset\n",
"        # Build the default here, not in the signature: a default argument is evaluated\n",
"        # once when the class is defined and would be shared by every instance.\n",
"        if feature_extractor is None:\n",
"            feature_extractor = SegformerFeatureExtractor(reduce_labels=True)\n",
"        self.feature_extractor = feature_extractor\n",
"        self.length = len(self.dataset)\n",
"        print(f\"Loaded {self.length} samples.\")\n",
"\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Return the number of samples in the dataset.\"\"\"\n",
"        return self.length\n",
"\n",
"\n",
"    def __getitem__(self, index: int):\n",
"        \"\"\"Return the encoded sample (pixel_values, labels) at the given index.\"\"\"\n",
"        # Fetch the row once: every lookup re-reads and decodes both images.\n",
"        sample = self.dataset[index]\n",
"        encoded_inputs = self.feature_extractor(\n",
"            sample[\"pixel_values\"], sample[\"label\"], return_tensors=\"pt\"\n",
"        )\n",
"\n",
"        # return_tensors=\"pt\" adds a leading batch axis; drop only that axis.\n",
"        for tensor in encoded_inputs.values():\n",
"            tensor.squeeze_(0)\n",
"\n",
"        return encoded_inputs"
] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "BATCH_SIZE = 32\n", "HUB_DIR = \"segments/sidewalk-semantic\"\n", "EPOCHS = 200\n", "SEED = 42  # seeds the train/validation split so it is reproducible" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using custom data configuration segments--sidewalk-semantic-2-f89d0845be9cadc9\n", "Reusing dataset parquet (/home/chainyo/.cache/huggingface/datasets/segments___parquet/segments--sidewalk-semantic-2-f89d0845be9cadc9/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loaded 800 samples.\n", "Loaded 200 samples.\n" ] } ], "source": [
"dataset = load_dataset(HUB_DIR, split=\"train\")\n",
"\n",
"# 80/20 split; the seeded generator puts the same samples in each part on every run.\n",
"train_size = int(0.8 * len(dataset))\n",
"train_subset, val_subset = random_split(\n",
"    dataset,\n",
"    [train_size, len(dataset) - train_size],\n",
"    generator=torch.Generator().manual_seed(SEED),\n",
")\n",
"train_dataset = SemanticSegmentationDataset(train_subset)\n",
"val_dataset = SemanticSegmentationDataset(val_subset)"
] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
"val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pixel_values torch.Size([32, 3, 512, 512])\n", "labels torch.Size([32, 512, 512])\n" ] } ], "source": [ "batch = next(iter(train_dataloader))\n", "\n", "for k, v in batch.items():\n", "    print(k, v.shape)" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [
"# Draft LightningModule version of the manual training loop below; not used yet.\n",
"# class SidewalkSegmentationModel(pl.LightningModule):\n",
"#     def __init__(self, num_classes: int, id2label: dict, label2id: dict, learning_rate: float = 6e-5):\n",
"#         super().__init__()\n",
"#         self.model = SegformerForSemanticSegmentation.from_pretrained(\n",
"#             \"nvidia/mit-b0\", num_labels=num_classes, id2label=id2label, label2id=label2id,\n",
"#         )\n",
"#         self.learning_rate = learning_rate\n",
"#         self.metric = load_metric(\"mean_iou\")\n",
"#\n",
"#     def forward(self, *args, **kwargs):\n",
"#         return self.model(*args, **kwargs)\n",
"#\n",
"#     def training_step(self, batch, batch_idx):\n",
"#         outputs = self(pixel_values=batch[\"pixel_values\"], labels=batch[\"labels\"])\n",
"#         self.log(\"train_loss\", outputs.loss)\n",
"#         return outputs.loss  # Lightning backpropagates the returned loss\n",
"#\n",
"#     def configure_optimizers(self) -> torch.optim.AdamW:\n",
"#         \"\"\"Configure the optimizer over this module's own parameters.\"\"\"\n",
"#         return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"with open(\"id2label.json\", \"r\") as id2label_fp:\n",
"    id2label_file = json.load(id2label_fp)\n",
"raw_id2label = {int(k): v for k, v in id2label_file.items()}\n",
"\n",
"# The feature extractors in this notebook use reduce_labels=True: class 0 (\"unlabeled\")\n",
"# is mapped to the ignore value 255 and every other class id is shifted down by one.\n",
"# Shift the mapping the same way so model output channel i is named after encoded label i.\n",
"id2label = {k - 1: v for k, v in raw_id2label.items() if k != 0}\n",
"label2id = {v: k for k, v in id2label.items()}  # int ids, not the JSON's string keys\n",
"num_labels = len(id2label)\n",
"print(id2label)\n",
"\n",
"model = SegformerForSemanticSegmentation.from_pretrained(\n",
"    \"nvidia/mit-b0\", num_labels=num_labels, id2label=id2label, label2id=label2id,\n",
")"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"def update_confusion_matrix(confusion, predictions, references, num_labels, ignore_index=255):\n",
"    \"\"\"\n",
"    Add one batch of predictions to a running confusion matrix, in place.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    confusion : torch.Tensor\n",
"        Integer tensor of shape (num_labels, num_labels): rows are reference classes,\n",
"        columns are predicted classes. Updated in place.\n",
"    predictions, references : torch.Tensor\n",
"        Class-index tensors of the same shape, on the same device as `confusion`.\n",
"    num_labels : int\n",
"        Number of classes the model predicts.\n",
"    ignore_index : int, optional\n",
"        Reference value whose pixels are skipped (255 in this notebook).\n",
"    \"\"\"\n",
"    valid = references != ignore_index\n",
"    # Encode each (reference, prediction) pair as one bin index, then count all pairs at once.\n",
"    pair_index = references[valid] * num_labels + predictions[valid]\n",
"    counts = torch.bincount(pair_index, minlength=num_labels * num_labels)\n",
"    confusion += counts.reshape(num_labels, num_labels)\n",
"\n",
"\n",
"def confusion_metrics(confusion):\n",
"    \"\"\"\n",
"    Summarise a confusion matrix built by `update_confusion_matrix`.\n",
"\n",
"    Classes with no reference and no predicted pixels are left out of `mean_iou`;\n",
"    classes with no reference pixels are left out of `mean_accuracy`.\n",
"\n",
"    Returns\n",
"    -------\n",
"    dict\n",
"        `mean_iou`, `mean_accuracy` and `overall_accuracy` as Python floats.\n",
"    \"\"\"\n",
"    confusion = confusion.double()\n",
"    true_positive = confusion.diag()\n",
"    reference_total = confusion.sum(dim=1)\n",
"    predicted_total = confusion.sum(dim=0)\n",
"    union = reference_total + predicted_total - true_positive\n",
"\n",
"    seen = union > 0\n",
"    present = reference_total > 0\n",
"    return {\n",
"        \"mean_iou\": (true_positive[seen] / union[seen]).mean().item(),\n",
"        \"mean_accuracy\": (true_positive[present] / reference_total[present]).mean().item(),\n",
"        \"overall_accuracy\": (true_positive.sum() / confusion.sum()).item(),\n",
"    }"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"LEARNING_RATE = 6e-5\n",
"\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model.to(device)\n",
"\n",
"for epoch in range(EPOCHS):\n",
"    model.train()\n",
"    # Metrics accumulate over the whole epoch and are reported once, at its end.\n",
"    confusion = torch.zeros(num_labels, num_labels, dtype=torch.long, device=device)\n",
"    epoch_loss = 0.0\n",
"\n",
"    for batch in tqdm(train_dataloader, desc=f\"Epoch {epoch + 1}/{EPOCHS}\"):\n",
"        pixel_values = batch[\"pixel_values\"].to(device)\n",
"        labels = batch[\"labels\"].to(device)\n",
"\n",
"        optimizer.zero_grad()\n",
"\n",
"        outputs = model(pixel_values=pixel_values, labels=labels)\n",
"        loss, logits = outputs.loss, outputs.logits\n",
"\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        epoch_loss += loss.item()\n",
"\n",
"        with torch.no_grad():\n",
"            # Upsample the logits to the label resolution before taking the argmax.\n",
"            upsampled_logits = nn.functional.interpolate(\n",
"                logits, size=labels.shape[-2:], mode=\"bilinear\", align_corners=False\n",
"            )\n",
"            predicted = upsampled_logits.argmax(dim=1)\n",
"            update_confusion_matrix(confusion, predicted, labels, num_labels, ignore_index=255)\n",
"\n",
"    metrics = confusion_metrics(confusion)\n",
"    print(\n",
"        f\"Epoch {epoch + 1}/{EPOCHS} \"\n",
"        f\"Loss {epoch_loss / len(train_dataloader):.4f} \"\n",
"        f\"mIoU {metrics['mean_iou']:.4f} \"\n",
"        f\"mAcc {metrics['mean_accuracy']:.4f} \"\n",
"        f\"Acc {metrics['overall_accuracy']:.4f}\"\n",
"    )"
] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using custom data configuration segments--sidewalk-semantic-2-f89d0845be9cadc9\n", "Reusing dataset parquet (/home/chainyo/.cache/huggingface/datasets/segments___parquet/segments--sidewalk-semantic-2-f89d0845be9cadc9/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "1000\n" ] } ], "source": [ "import numpy as 
np\n", "\n",
"# Named `tokenizer` here, but this is the SegFormer feature extractor.\n",
"tokenizer = SegformerFeatureExtractor(reduce_labels=True)\n",
"\n",
"dataset = load_dataset(HUB_DIR, split=\"train\")\n",
"length = len(dataset)\n",
"print(length)\n",
"\n",
"# Encodes every sample eagerly into in-memory tensors (pixel_values, labels).\n",
"encoded_dataset = tokenizer(\n",
"    images=dataset[\"pixel_values\"], segmentation_maps=dataset[\"label\"], return_tensors=\"pt\"\n",
")" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "pixel_values = encoded_dataset[\"pixel_values\"]\n", "labels = encoded_dataset[\"labels\"]" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Tensor" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(pixel_values)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"from torch.utils.data import DataLoader, Dataset, random_split, Subset\n",
"from transformers import BatchFeature\n",
"\n",
"\n",
"class SegmentationDataset(Dataset):\n",
"    \"\"\"Dataset over pre-encoded pixel values and label maps held as tensors.\"\"\"\n",
"\n",
"    def __init__(self, pixel_values: torch.Tensor, labels: torch.Tensor):\n",
"        \"\"\"\n",
"        Parameters\n",
"        ----------\n",
"        pixel_values : torch.Tensor\n",
"            Encoded images; the first dimension indexes samples.\n",
"        labels : torch.Tensor\n",
"            Encoded segmentation maps, aligned sample-for-sample with pixel_values.\n",
"\n",
"        Raises\n",
"        ------\n",
"        ValueError\n",
"            If pixel_values and labels hold a different number of samples.\n",
"        \"\"\"\n",
"        if pixel_values.shape[0] != labels.shape[0]:\n",
"            raise ValueError(\n",
"                f\"pixel_values has {pixel_values.shape[0]} samples but labels has {labels.shape[0]}\"\n",
"            )\n",
"        self.pixel_values = pixel_values\n",
"        self.labels = labels\n",
"        self.length = pixel_values.shape[0]\n",
"        print(f\"Created dataset with {self.length} samples\")\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Return the number of samples.\"\"\"\n",
"        return self.length\n",
"\n",
"    def __getitem__(self, index):\n",
"        \"\"\"Return sample `index` as a BatchFeature with \"pixel_values\" and \"labels\" keys.\"\"\"\n",
"        return BatchFeature({\"pixel_values\": self.pixel_values[index], \"labels\": self.labels[index]})"
] },
{ "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Created dataset with 1000 samples\n" ] } ], "source": [ "segmentation_dataset = SegmentationDataset(pixel_values, labels)" ] },
{ "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "test = segmentation_dataset[0]" ] },
{ "cell_type": "code", "execution_count": 
null, "metadata": {}, "outputs": [], "source": [ "# Show shapes rather than dumping the tensors.\n", "{name: tensor.shape for name, tensor in test.items()}" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# BatchFeature is indexed by key, not by position.\n", "segmentation_dataset[0][\"labels\"].unique()" ] },
{ "cell_type": "code", "execution_count": 22, 
"metadata": {}, "outputs": [], "source": [
"# 80/20 split of the encoded samples. Sizes are derived from `length` so they always sum to it,\n",
"# and the fixed generator seed keeps the split reproducible.\n",
"train_size = int(length * 0.8)\n",
"train_dataset, val_dataset = random_split(\n",
"    segmentation_dataset,\n",
"    [train_size, length - train_size],\n",
"    generator=torch.Generator().manual_seed(42),\n",
")"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_dataset[0]" ] },
{ "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)\n", "valid_dataloader = DataLoader(val_dataset, batch_size=2)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "batch = next(iter(train_dataloader))" ] },
{ "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/chainyo/code/segformer-sidewalk/. is already a clone of https://huggingface.co/ChainYo/segformer-sidewalk. Make sure you pull the latest changes with `repo.git_pull()`.\n", "remote: Enforcing permissions... \n", "remote: Allowed refs: all \n", "To https://huggingface.co/ChainYo/segformer-sidewalk\n", " 5d5f276..56db83f main -> main\n", "\n" ] }, { "ename": "TypeError", "evalue": "__init__() got an unexpected keyword argument 'num_labels'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m/home/chainyo/code/segformer-sidewalk/finetuning.ipynb Cell 23'\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m config \u001b[39m=\u001b[39m AutoConfig\u001b[39m.\u001b[39mfrom_pretrained(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mnvidia/mit-b0\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 9\u001b[0m config\u001b[39m.\u001b[39mpush_to_hub(\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m, repo_url\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mhttps://huggingface.co/ChainYo/segformer-sidewalk\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m---> 11\u001b[0m model \u001b[39m=\u001b[39m 
SegformerForSemanticSegmentation\u001b[39m.\u001b[39;49mfrom_pretrained(\n\u001b[1;32m 12\u001b[0m \u001b[39m\"\u001b[39;49m\u001b[39m/home/chainyo/code/segformer-sidewalk/checkpoints/epoch=44-step=1125.ckpt\u001b[39;49m\u001b[39m\"\u001b[39;49m, \n\u001b[1;32m 13\u001b[0m num_labels\u001b[39m=\u001b[39;49mnum_labels, \n\u001b[1;32m 14\u001b[0m id2label\u001b[39m=\u001b[39;49mid2label, \n\u001b[1;32m 15\u001b[0m label2id\u001b[39m=\u001b[39;49mid2label,\n\u001b[1;32m 16\u001b[0m config\u001b[39m=\u001b[39;49mconfig,\n\u001b[1;32m 17\u001b[0m )\n\u001b[1;32m 18\u001b[0m model\u001b[39m.\u001b[39mpush_to_hub(\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m, repo_url\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mhttps://huggingface.co/ChainYo/segformer-sidewalk\u001b[39m\u001b[39m\"\u001b[39m)\n", "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py:2024\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 2022\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 2023\u001b[0m \u001b[39mwith\u001b[39;00m no_init_weights(_enable\u001b[39m=\u001b[39m_fast_init):\n\u001b[0;32m-> 2024\u001b[0m model \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39;49m(config, \u001b[39m*\u001b[39;49mmodel_args, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mmodel_kwargs)\n\u001b[1;32m 2026\u001b[0m \u001b[39mif\u001b[39;00m from_tf:\n\u001b[1;32m 2027\u001b[0m \u001b[39mif\u001b[39;00m resolved_archive_file\u001b[39m.\u001b[39mendswith(\u001b[39m\"\u001b[39m\u001b[39m.index\u001b[39m\u001b[39m\"\u001b[39m):\n\u001b[1;32m 2028\u001b[0m \u001b[39m# Load from a TensorFlow 1.X checkpoint - provided by original authors\u001b[39;00m\n", "\u001b[0;31mTypeError\u001b[0m: __init__() got an unexpected keyword argument 'num_labels'" ] } ], "source": [ "import json\n", "from transformers import AutoConfig\n", "\n", "id2label_file = 
json.load(open(\"id2label.json\", \"r\"))\n", "# JSON object keys are always strings, so convert the ids to int for the model config.\n", "id2label = {int(k): v for k, v in id2label_file.items()}\n", "num_labels = len(id2label)\n", "\n", "config = AutoConfig.from_pretrained(\"nvidia/mit-b0\")\n", "config.num_labels = num_labels\n", "config.id2label = id2label\n", "# Invert the int-keyed mapping so label2id values are ints, not the string keys read from JSON.\n", "config.label2id = {v: k for k, v in id2label.items()}\n", "config.push_to_hub(\".\", repo_url=\"https://huggingface.co/ChainYo/segformer-sidewalk\")\n", "\n", "# NOTE(review): from_pretrained is documented to take a Hub id or a save_pretrained() directory;\n", "# this path is a PyTorch Lightning .ckpt file, whose saved keys may not match the model's parameter names.\n", "# TODO confirm the fine-tuned weights actually load (watch for newly-initialized-weights warnings).\n", "model = SegformerForSemanticSegmentation.from_pretrained(\n", " \"/home/chainyo/code/segformer-sidewalk/checkpoints/epoch=44-step=1125.ckpt\", \n", " config=config,\n", ")\n", "model.push_to_hub(\".\", repo_url=\"https://huggingface.co/ChainYo/segformer-sidewalk\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "interpreter": { "hash": "19a4294f02aad727716e8ed0f765e04171ea0ecb4c129d0b0eebf53be4c3a095" }, "kernelspec": { "display_name": "Python 3.8.13 ('segformer')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }