Spaces:
Runtime error
Runtime error
add dps data
Browse files- diffusion-posterior-sampling/data/__pycache__/dataloader.cpython-38.pyc +0 -0
- diffusion-posterior-sampling/data/dataloader.py +55 -0
- diffusion-posterior-sampling/data/samples/00000.png +0 -0
- diffusion-posterior-sampling/data/samples/00001.png +0 -0
- diffusion-posterior-sampling/data/samples/00003.png +0 -0
- diffusion-posterior-sampling/data/samples/00004.png +0 -0
- diffusion-posterior-sampling/data/samples/00008.png +0 -0
- diffusion-posterior-sampling/data/samples/00014.png +0 -0
- diffusion-posterior-sampling/data/samples/00015.png +0 -0
- diffusion-posterior-sampling/data/samples/00017.png +0 -0
- diffusion-posterior-sampling/data/samples/00019.png +0 -0
- diffusion-posterior-sampling/data/samples/00024.png +0 -0
- diffusion-posterior-sampling/data/samples/00048.png +0 -0
- diffusion-posterior-sampling/data/samples/00261.png +0 -0
- diffusion-posterior-sampling/data/samples/00478.png +0 -0
- diffusion-posterior-sampling/data/samples/00535.png +0 -0
diffusion-posterior-sampling/data/__pycache__/dataloader.cpython-38.pyc
ADDED
Binary file (2.23 kB). View file
|
|
diffusion-posterior-sampling/data/dataloader.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from glob import glob
|
2 |
+
from PIL import Image
|
3 |
+
from typing import Callable, Optional
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from torchvision.datasets import VisionDataset
|
6 |
+
|
7 |
+
|
8 |
+
__DATASET__ = {}
|
9 |
+
|
10 |
+
def register_dataset(name: str):
|
11 |
+
def wrapper(cls):
|
12 |
+
if __DATASET__.get(name, None):
|
13 |
+
raise NameError(f"Name {name} is already registered!")
|
14 |
+
__DATASET__[name] = cls
|
15 |
+
return cls
|
16 |
+
return wrapper
|
17 |
+
|
18 |
+
|
19 |
+
def get_dataset(name: str, root: str, **kwargs):
|
20 |
+
if __DATASET__.get(name, None) is None:
|
21 |
+
raise NameError(f"Dataset {name} is not defined.")
|
22 |
+
return __DATASET__[name](root=root, **kwargs)
|
23 |
+
|
24 |
+
|
25 |
+
def get_dataloader(dataset: VisionDataset,
|
26 |
+
batch_size: int,
|
27 |
+
num_workers: int,
|
28 |
+
train: bool):
|
29 |
+
dataloader = DataLoader(dataset,
|
30 |
+
batch_size,
|
31 |
+
shuffle=train,
|
32 |
+
num_workers=num_workers,
|
33 |
+
drop_last=train)
|
34 |
+
return dataloader
|
35 |
+
|
36 |
+
|
37 |
+
@register_dataset(name='ffhq')
|
38 |
+
class FFHQDataset(VisionDataset):
|
39 |
+
def __init__(self, root: str, transforms: Optional[Callable]=None):
|
40 |
+
super().__init__(root, transforms)
|
41 |
+
|
42 |
+
self.fpaths = sorted(glob(root + '/**/*.png', recursive=True))
|
43 |
+
assert len(self.fpaths) > 0, "File list is empty. Check the root."
|
44 |
+
|
45 |
+
def __len__(self):
|
46 |
+
return len(self.fpaths)
|
47 |
+
|
48 |
+
def __getitem__(self, index: int):
|
49 |
+
fpath = self.fpaths[index]
|
50 |
+
img = Image.open(fpath).convert('RGB')
|
51 |
+
|
52 |
+
if self.transforms is not None:
|
53 |
+
img = self.transforms(img)
|
54 |
+
|
55 |
+
return img
|
diffusion-posterior-sampling/data/samples/00000.png
ADDED
diffusion-posterior-sampling/data/samples/00001.png
ADDED
diffusion-posterior-sampling/data/samples/00003.png
ADDED
diffusion-posterior-sampling/data/samples/00004.png
ADDED
diffusion-posterior-sampling/data/samples/00008.png
ADDED
diffusion-posterior-sampling/data/samples/00014.png
ADDED
diffusion-posterior-sampling/data/samples/00015.png
ADDED
diffusion-posterior-sampling/data/samples/00017.png
ADDED
diffusion-posterior-sampling/data/samples/00019.png
ADDED
diffusion-posterior-sampling/data/samples/00024.png
ADDED
diffusion-posterior-sampling/data/samples/00048.png
ADDED
diffusion-posterior-sampling/data/samples/00261.png
ADDED
diffusion-posterior-sampling/data/samples/00478.png
ADDED
diffusion-posterior-sampling/data/samples/00535.png
ADDED