import pytorch_lightning as pl
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import BatchFeature, SegformerFeatureExtractor
from typing import Optional


class SegmentationDataset(Dataset):
    """Image Segmentation Dataset"""

    def __init__(self, pixel_values: torch.Tensor, labels: torch.Tensor):
        """
        Dataset for image segmentation.

        Parameters
        ----------
        pixel_values : torch.Tensor
            Tensor of shape (N, C, H, W) containing the pixel values of the images.
        labels : torch.Tensor
            Tensor of shape (N, H, W) containing the segmentation maps of the images.
        """
        assert pixel_values.shape[0] == labels.shape[0], "pixel_values and labels must contain the same number of samples"
        self.pixel_values = pixel_values
        self.labels = labels
        self.length = pixel_values.shape[0]
        print(f"Created dataset with {self.length} samples")

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        image = self.pixel_values[index]
        label = self.labels[index]
        # Wrap each sample in a BatchFeature so the default collate function
        # yields the dict of tensors ("pixel_values", "labels") SegFormer expects.
        return BatchFeature({"pixel_values": image, "labels": label})


class SidewalkSegmentationDataLoader(pl.LightningDataModule):
    def __init__(
        self,
        hub_dir: str,
        batch_size: int,
        split: Optional[str] = None,
    ):
        super().__init__()
        self.hub_dir = hub_dir
        self.batch_size = batch_size
        # reduce_labels=True shifts every label id down by one and maps the
        # background class (0) to 255, which SegFormer ignores in its loss.
        self.feature_extractor = SegformerFeatureExtractor(reduce_labels=True)
        self.dataset = load_dataset(self.hub_dir, split=split)
        self.num_samples = len(self.dataset)

    def tokenize_data(self, *args, **kwargs):
        return self.feature_extractor(*args, **kwargs)

    def setup(self, stage: Optional[str] = None):
        encoded_dataset = self.tokenize_data(
            images=self.dataset["pixel_values"],
            segmentation_maps=self.dataset["label"],
            return_tensors="pt",
        )
        dataset = SegmentationDataset(encoded_dataset["pixel_values"], encoded_dataset["labels"])
        # Derive the validation size from the remainder so the two sizes always
        # sum to the dataset length; rounding both with int(...) can otherwise
        # make random_split raise a ValueError for lengths not divisible by 5.
        train_size = int(self.num_samples * 0.8)
        val_size = self.num_samples - train_size
        self.train_dataset, self.val_dataset = random_split(dataset, [train_size, val_size])

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=12)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=12)
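

# Usage: a minimal sketch of driving the DataModule by hand. The Hub dataset
# id "segments/sidewalk-semantic", the batch size, and the split are
# illustrative assumptions, not part of the module above; substitute your own
# values. It instantiates the DataModule, runs setup(), and pulls one batch to
# sanity-check the tensor shapes before handing it to a Trainer.
if __name__ == "__main__":
    datamodule = SidewalkSegmentationDataLoader(
        hub_dir="segments/sidewalk-semantic",  # assumed Hub dataset id
        batch_size=8,
        split="train",
    )
    datamodule.setup()
    batch = next(iter(datamodule.train_dataloader()))
    print(batch["pixel_values"].shape)  # e.g. (8, 3, 512, 512) after preprocessing
    print(batch["labels"].shape)        # e.g. (8, 512, 512)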