meeww commited on
Commit
c5c5c1d
1 Parent(s): 89eaa78

Upload image_dataset.py

Browse files
Files changed (1) hide show
  1. image_dataset.py +24 -0
image_dataset.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from os.path import join
3
+
4
+ import torchvision.transforms as transforms
5
+ from torch.utils.data import Dataset
6
+
7
+ from PIL import Image
8
+
9
+
10
+ class ImageDataset(Dataset):
11
+ def __init__(self, folder_path):
12
+ self.transform = transforms.Compose([
13
+ transforms.ToTensor(),
14
+ transforms.Normalize((0.5, 0.5, 0.5, 0.5), (0.5, 0.5, 0.5, 0.5))
15
+ ])
16
+ self.files = [
17
+ join(folder_path, file) for file in os.listdir(folder_path)
18
+ ]
19
+
20
+ def __getitem__(self, index):
21
+ return self.transform(Image.open(self.files[index % len(self.files)]))
22
+
23
+ def __len__(self):
24
+ return len(self.files)