|
import torch |
|
from torch.utils.data import Dataset |
|
|
|
|
|
class ImagesDataset(Dataset): |
|
def __init__(self, images: dict[torch.Tensor, list[str]] | list[torch.Tensor]): |
|
if isinstance(images, list): |
|
images = dict.fromkeys(images) |
|
|
|
self.images = list(images) |
|
self.names = list(images.values()) |
|
|
|
def __len__(self): |
|
return len(self.images) |
|
|
|
def __getitem__(self, index): |
|
image = self.images[index] |
|
|
|
if image.dtype is torch.uint8: |
|
image = image / 255 |
|
|
|
names = self.names[index] |
|
return image, names |
|
|
|
|
|
def image_collate(batch):
    """Collate ``(image, names)`` samples into a batch.

    The first element of every sample is stacked along a new leading
    dimension; the second elements are gathered, unmodified and in order,
    into a plain list.

    Args:
        batch: Sequence of indexable samples, e.g. as yielded by
            ``ImagesDataset.__getitem__``.

    Returns:
        A ``(stacked_images, names_list)`` tuple.
    """
    stacked = torch.stack(tuple(sample[0] for sample in batch))
    labels = [sample[1] for sample in batch]
    return stacked, labels
|
|