tfwang commited on
Commit
2651d59
1 Parent(s): d7db483

Update glide_text2im/image_datasets_sketch.py

Browse files
glide_text2im/image_datasets_sketch.py CHANGED
@@ -3,7 +3,7 @@ import random
3
 
4
  from PIL import Image
5
  import blobfile as bf
6
- from mpi4py import MPI
7
  import numpy as np
8
  from torch.utils.data import DataLoader, Dataset
9
  import os
@@ -13,169 +13,7 @@ from .degradation.bsrgan_light import degradation_bsrgan_variant as degradation_
13
  from functools import partial
14
  import cv2
15
 
16
- from PIL import PngImagePlugin
17
- LARGE_ENOUGH_NUMBER = 100
18
- PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
19
 
20
- def load_data_sketch(
21
- *,
22
- data_dir,
23
- batch_size,
24
- image_size,
25
- class_cond=False,
26
- deterministic=False,
27
- random_crop=False,
28
- random_flip=True,
29
- train=True,
30
- low_res = 0,
31
- uncond_p = 0,
32
- mode = ''
33
- ):
34
- """
35
- For a dataset, create a generator over (images, kwargs) pairs.
36
-
37
- Each images is an NCHW float tensor, and the kwargs dict contains zero or
38
- more keys, each of which map to a batched Tensor of their own.
39
- The kwargs dict can be used for class labels, in which case the key is "y"
40
- and the values are integer tensors of class labels.
41
-
42
- :param data_dir: a dataset directory.
43
- :param batch_size: the batch size of each returned pair.
44
- :param image_size: the size to which images are resized.
45
- :param class_cond: if True, include a "y" key in returned dicts for class
46
- label. If classes are not available and this is true, an
47
- exception will be raised.
48
- :param deterministic: if True, yield results in a deterministic order.
49
- :param random_crop: if True, randomly crop the images for augmentation.
50
- :param random_flip: if True, randomly flip the images for augmentation.
51
- """
52
- if not data_dir:
53
- raise ValueError("unspecified data directory")
54
- with open(data_dir) as f:
55
- all_files = f.read().splitlines()
56
-
57
- print(len(all_files))
58
- classes = None
59
- if class_cond:
60
- # Assume classes are the first part of the filename,
61
- # before an underscore.
62
- class_names = [bf.basename(path).split("_")[0] for path in all_files]
63
- sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
64
- classes = [sorted_classes[x] for x in class_names]
65
- dataset = ImageDataset(
66
- image_size,
67
- all_files,
68
- classes=classes,
69
- shard=MPI.COMM_WORLD.Get_rank(),
70
- num_shards=MPI.COMM_WORLD.Get_size(),
71
- random_crop=random_crop,
72
- random_flip=train,
73
- down_sample_img_size = low_res,
74
- uncond_p = uncond_p,
75
- mode = mode,
76
- )
77
- if deterministic:
78
- loader = DataLoader(
79
- dataset, batch_size=batch_size, shuffle=False, num_workers=8, drop_last=True, pin_memory=False
80
- )
81
- else:
82
- loader = DataLoader(
83
- dataset, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True, pin_memory=False
84
- )
85
- while True:
86
- yield from loader
87
-
88
- def _list_image_files_recursively(data_dir):
89
- results = []
90
- for entry in sorted(bf.listdir(data_dir)):
91
- full_path = bf.join(data_dir, entry)
92
- ext = entry.split(".")[-1]
93
- if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
94
- results.append(full_path)
95
- elif bf.isdir(full_path):
96
- results.extend(_list_image_files_recursively(full_path))
97
- return results
98
-
99
- class ImageDataset(Dataset):
100
- def __init__(
101
- self,
102
- resolution,
103
- image_paths,
104
- classes=None,
105
- shard=0,
106
- num_shards=1,
107
- random_crop=False,
108
- random_flip=True,
109
- down_sample_img_size = 0,
110
- uncond_p = 0,
111
- mode = '',
112
- ):
113
- super().__init__()
114
- self.crop_size = 256
115
- self.resize_size = 256
116
- self.local_images = image_paths[shard:][::num_shards]
117
- self.local_classes = None if classes is None else classes[shard:][::num_shards]
118
- self.random_crop = random_crop
119
- self.random_flip = random_flip
120
-
121
- self.down_sample_img = partial(degradation_fn_bsr_light, sf=resolution//down_sample_img_size) if down_sample_img_size else None
122
- self.uncond_p = uncond_p
123
- self.mode = mode
124
- self.resolution = resolution
125
-
126
- def __len__(self):
127
- return len(self.local_images)
128
-
129
- def __getitem__(self, idx):
130
- if self.mode == 'coco-edge':
131
- path = self.local_images[idx].replace('COCO-STUFF', 'COCO-Sketch')[:-4] + '.png'
132
- path2 = path.replace('_img', '_sketch')
133
- elif self.mode == 'flickr-edge':
134
- path = self.local_images[idx].replace('images', 'img256')[:-4] + '.png'
135
- path2 = path.replace('img256', 'sketch256')
136
-
137
-
138
- with bf.BlobFile(path, "rb") as f:
139
- pil_image = Image.open(f)
140
- pil_image.load()
141
- pil_image = pil_image.convert("RGB")
142
-
143
-
144
- with bf.BlobFile(path2, "rb") as f:
145
- pil_image2 = Image.open(f)
146
- pil_image2.load()
147
- pil_image2 = pil_image2.convert("L")
148
-
149
-
150
- params = get_params(pil_image2.size, self.resize_size, self.crop_size)
151
- transform_label = get_transform(params, self.resize_size, self.crop_size, method=Image.NEAREST, crop =self.random_crop, flip=self.random_flip)
152
- label_pil = transform_label(pil_image2)
153
-
154
- im_dist = cv2.distanceTransform(255-np.array(label_pil), cv2.DIST_L1, 3)
155
- im_dist = np.clip((im_dist) , 0, 255).astype(np.uint8)
156
- im_dist = Image.fromarray(im_dist).convert("RGB")
157
-
158
- label_tensor = get_tensor()(im_dist)[:1]
159
- label_tensor_ori = get_tensor()(label_pil.convert('RGB'))
160
-
161
- transform_image = get_transform( params, self.resize_size, self.crop_size, crop =self.random_crop, flip=self.random_flip)
162
- image_pil = transform_image(pil_image)
163
- if self.resolution < 256:
164
- image_pil = image_pil.resize((self.resolution, self.resolution), Image.BICUBIC)
165
- image_tensor = get_tensor()(image_pil)
166
-
167
- if self.down_sample_img:
168
- image_pil = np.array(image_pil).astype(np.uint8)
169
- down_sampled_image = self.down_sample_img(image=image_pil)["image"]
170
- down_sampled_image = get_tensor()(down_sampled_image)
171
- data_dict = {"ref":label_tensor, "low_res":down_sampled_image, "ref_ori":label_tensor_ori, "path": path}
172
- return image_tensor, data_dict
173
-
174
- if random.random() < self.uncond_p:
175
- label_tensor = th.ones_like(label_tensor)
176
- data_dict = {"ref":label_tensor, "ref_ori":label_tensor_ori, "path": path}
177
-
178
- return image_tensor, data_dict
179
 
180
  def get_params( size, resize_size, crop_size):
181
  w, h = size
 
3
 
4
  from PIL import Image
5
  import blobfile as bf
6
+ #from mpi4py import MPI
7
  import numpy as np
8
  from torch.utils.data import DataLoader, Dataset
9
  import os
 
13
  from functools import partial
14
  import cv2
15
 
 
 
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def get_params( size, resize_size, crop_size):
19
  w, h = size