create dataset and dataloader
dataloader.py +79 -0
dataloader.py
ADDED
@@ -0,0 +1,79 @@
import pytorch_lightning as pl
import torch

from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import SegformerFeatureExtractor, BatchFeature

from typing import Optional


class SegmentationDataset(Dataset):
    """Image Segmentation Dataset"""

    def __init__(self, pixel_values: torch.Tensor, labels: torch.Tensor):
        """
        Dataset for image segmentation.

        Parameters
        ----------
        pixel_values : torch.Tensor
            Tensor of shape (N, C, H, W) containing the pixel values of the images.
        labels : torch.Tensor
            Tensor of shape (N, H, W) containing the segmentation maps of the images.
        """
        assert pixel_values.shape[0] == labels.shape[0], "pixel_values and labels must hold the same number of samples"
        self.pixel_values = pixel_values
        self.labels = labels
        self.length = pixel_values.shape[0]
        print(f"Created dataset with {self.length} samples")

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        image = self.pixel_values[index]
        label = self.labels[index]

        # Wrap the sample in a BatchFeature so each item already carries the
        # keys a SegFormer model expects.
        encoded_inputs = BatchFeature({"pixel_values": image, "labels": label})

        return encoded_inputs


class SidewalkSegmentationDataLoader(pl.LightningDataModule):
    def __init__(self, hub_dir: str, batch_size: int, split: Optional[str] = None):
        super().__init__()
        self.hub_dir = hub_dir
        self.batch_size = batch_size
        # reduce_labels=True shifts all labels down by one and maps the
        # background class (0) to 255 so the loss can ignore it.
        self.tokenizer = SegformerFeatureExtractor(reduce_labels=True)
        # Note: split must name a concrete split (e.g. "train"); with
        # split=None, load_dataset returns a DatasetDict rather than a Dataset.
        self.dataset = load_dataset(self.hub_dir, split=split)
        self.len = len(self.dataset)

    def tokenize_data(self, *args, **kwargs):
        return self.tokenizer(*args, **kwargs)

    def setup(self, stage: Optional[str] = None):
        # Preprocess the whole dataset up front; the feature extractor
        # resizes and normalizes both the images and the segmentation maps.
        encoded_dataset = self.tokenize_data(
            images=self.dataset["pixel_values"],
            segmentation_maps=self.dataset["label"],
            return_tensors="pt",
        )
        dataset = SegmentationDataset(encoded_dataset["pixel_values"], encoded_dataset["labels"])

        # 80/20 train/validation split. Deriving the validation size by
        # subtraction keeps the two parts summing to the dataset length even
        # when it is not divisible by five (random_split raises otherwise).
        train_size = int(self.len * 0.8)
        val_size = self.len - train_size
        self.train_dataset, self.val_dataset = random_split(dataset, [train_size, val_size])

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=12)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=12)