chainyo committed on
Commit
3110ea7
1 Parent(s): ff6340b

create dataset and dataloader

Files changed (1)
  1. dataloader.py +79 -0
dataloader.py ADDED
@@ -0,0 +1,79 @@
+ import numpy as np
+ import pytorch_lightning as pl
+ import torch
+
+ from datasets import load_dataset
+ from torch.utils.data import DataLoader, Dataset, random_split, Subset
+ from transformers import SegformerFeatureExtractor, BatchFeature
+
+ from typing import Optional
+
+
+ class SegmentationDataset(Dataset):
+     """Image Segmentation Dataset"""
+
+     def __init__(self, pixel_values: torch.Tensor, labels: torch.Tensor):
+         """
+         Dataset for image segmentation.
+
+         Parameters
+         ----------
+         pixel_values : torch.Tensor
+             Tensor of shape (N, C, H, W) containing the pixel values of the images.
+         labels : torch.Tensor
+             Tensor of shape (N, H, W) containing the segmentation maps of the images.
+         """
+         self.pixel_values = pixel_values
+         self.labels = labels
+         assert pixel_values.shape[0] == labels.shape[0], "pixel_values and labels must have the same number of samples"
+         self.length = pixel_values.shape[0]
+         print(f"Created dataset with {self.length} samples")
+
+     def __len__(self):
+         return self.length
+
+     def __getitem__(self, index):
+         image = self.pixel_values[index]
+         label = self.labels[index]
+
+         encoded_inputs = BatchFeature({"pixel_values": image, "labels": label})
+
+         return encoded_inputs
+
+
+ class SidewalkSegmentationDataLoader(pl.LightningDataModule):
+     def __init__(
+         self, hub_dir: str, batch_size: int, split: Optional[str] = None,
+     ):
+         super().__init__()
+         self.hub_dir = hub_dir
+         self.batch_size = batch_size
+         # reduce_labels=True shifts every label down by one and maps the background
+         # class (0) to 255 so that it is ignored by the loss.
+         self.tokenizer = SegformerFeatureExtractor(reduce_labels=True)
+         self.dataset = load_dataset(self.hub_dir, split=split)
+         self.len = len(self.dataset)
+
+     def tokenize_data(self, *args, **kwargs):
+         return self.tokenizer(*args, **kwargs)
+
+     def setup(self, stage: Optional[str] = None):
+         encoded_dataset = self.tokenize_data(
+             images=self.dataset["pixel_values"], segmentation_maps=self.dataset["label"], return_tensors="pt"
+         )
+         dataset = SegmentationDataset(encoded_dataset["pixel_values"], encoded_dataset["labels"])
+
+         # Give the remainder to the validation split: random_split raises an error
+         # whenever int(0.8 * n) + int(0.2 * n) != n, e.g. for n = 999.
+         indices = np.arange(self.len)
+         train_size = int(self.len * 0.8)
+         train_indices, val_indices = random_split(indices, [train_size, self.len - train_size])
+
+         self.train_dataset = Subset(dataset, train_indices)
+         self.val_dataset = Subset(dataset, val_indices)
+
+     def train_dataloader(self):
+         return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=12)
+
+     def val_dataloader(self):
+         return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=12)
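For reference, a minimal usage sketch of the data module (not part of this commit). The Hub dataset id and batch size below are assumptions chosen for illustration, not values taken from the code above:

from dataloader import SidewalkSegmentationDataLoader

# Hypothetical Hub dataset id and batch size, for illustration only.
datamodule = SidewalkSegmentationDataLoader(
    hub_dir="segments/sidewalk-semantic",
    batch_size=8,
    split="train",
)
datamodule.setup()

# Each batch collates to (B, C, H, W) pixel values and (B, H, W) labels.
batch = next(iter(datamodule.train_dataloader()))
print(batch["pixel_values"].shape, batch["labels"].shape)

Note that setup() tokenizes the entire split up front and holds it in memory, which is fine for a dataset of this size but would not scale to much larger corpora.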