SunderAli17 committed on
Commit 2f74861
1 Parent(s): 7707ac5

Create dataset.py

Files changed (1)
  1. data/dataset.py +202 -0
data/dataset.py ADDED
@@ -0,0 +1,202 @@
from pathlib import Path
from typing import Optional

from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms
import json
import random
from facenet_pytorch import MTCNN
import torch

from utils.utils import extract_faces_and_landmarks, REFERNCE_FACIAL_POINTS_RELATIVE


def load_image(image_path: str) -> Image.Image:
    image = Image.open(image_path)
    image = exif_transpose(image)
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image


class ImageDataset(Dataset):
    """
    A dataset that prepares the instance (target) and encoder (reference) images together with
    their prompts for fine-tuning the model. It pre-processes the images.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        metadata_path: Optional[str] = None,
        prompt_in_filename=False,
        use_only_vanilla_for_encoder=False,
        concept_placeholder='a face',
        size=1024,
        center_crop=False,
        aug_images=False,
        use_only_decoder_prompts=False,
        crop_head_for_encoder_image=False,
        random_target_prob=0.0,
    ):
        # MTCNN is only used for face detection, so route forward() to detect().
        self.mtcnn = MTCNN(device='cuda:0')
        self.mtcnn.forward = self.mtcnn.detect
        resize_factor = 1.3
        # Rescale the canonical facial reference points to fit a crop enlarged by resize_factor.
        self.resized_reference_points = REFERNCE_FACIAL_POINTS_RELATIVE / resize_factor + (resize_factor - 1) / (2 * resize_factor)
        self.size = size
        self.center_crop = center_crop
        self.concept_placeholder = concept_placeholder
        self.prompt_in_filename = prompt_in_filename
        self.aug_images = aug_images

        self.instance_prompt = instance_prompt
        self.custom_instance_prompts = None
        self.name_to_label = None
        self.crop_head_for_encoder_image = crop_head_for_encoder_image
        self.random_target_prob = random_target_prob

        self.use_only_decoder_prompts = use_only_decoder_prompts

        self.instance_data_root = Path(instance_data_root)

        if not self.instance_data_root.exists():
            raise ValueError(f"Instance images root {self.instance_data_root} doesn't exist.")

        if metadata_path is not None:
            with open(metadata_path, 'r') as f:
                self.name_to_label = json.load(f)  # dict of filename: label
            # Create a reversed mapping
            self.label_to_names = {}
            for name, label in self.name_to_label.items():
                if use_only_vanilla_for_encoder and 'vanilla' not in name:
                    continue
                if label not in self.label_to_names:
                    self.label_to_names[label] = []
                self.label_to_names[label].append(name)
            self.all_paths = [self.instance_data_root / filename for filename in self.name_to_label.keys()]

            # Verify all paths exist
            n_all_paths = len(self.all_paths)
            self.all_paths = [path for path in self.all_paths if path.exists()]
            print(f'Found {len(self.all_paths)} out of {n_all_paths} paths.')
        else:
            self.all_paths = [path for path in list(Path(instance_data_root).glob('**/*')) if
                              path.suffix.lower() in [".png", ".jpg", ".jpeg"]]
            # Sort by name so that order for validation remains the same across runs
            self.all_paths = sorted(self.all_paths, key=lambda x: x.stem)

        self.custom_instance_prompts = None

        self._length = len(self.all_paths)

        self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        if self.prompt_in_filename:
            self.prompts_set = {self._path_to_prompt(path) for path in self.all_paths}
        else:
            self.prompts_set = {self.instance_prompt}

        if self.aug_images:
            self.aug_transforms = transforms.Compose(
                [
                    transforms.RandomResizedCrop(size, scale=(0.8, 1.0), ratio=(1.0, 1.0)),
                    transforms.RandomHorizontalFlip(p=0.5),
                ]
            )

    def __len__(self):
        return self._length

    def _path_to_prompt(self, path):
        # Remove the extension and seed
        split_path = path.stem.split('_')
        while split_path[-1].isnumeric():
            split_path = split_path[:-1]

        prompt = ' '.join(split_path)
        # Replace placeholder in prompt with training placeholder
        prompt = prompt.replace('conceptname', self.concept_placeholder)
        return prompt

    def __getitem__(self, index):
        example = {}
        instance_path = self.all_paths[index]
        instance_image = load_image(instance_path)
        example["instance_images"] = self.image_transforms(instance_image)
        if self.prompt_in_filename:
            example["instance_prompt"] = self._path_to_prompt(instance_path)
        else:
            example["instance_prompt"] = self.instance_prompt

        if self.name_to_label is None:
            # If no labels, simply take the same image but with different augmentation
            example["encoder_images"] = self.aug_transforms(example["instance_images"]) if self.aug_images else example["instance_images"]
            example["encoder_prompt"] = example["instance_prompt"]
        else:
            # Randomly select another image with the same label
            instance_name = str(instance_path.relative_to(self.instance_data_root))
            instance_label = self.name_to_label[instance_name]
            label_set = set(self.label_to_names[instance_label])
            if len(label_set) == 1:
                # We are not supposed to have only one image per label, but just in case
                encoder_image_name = instance_name
                print(f'WARNING: Only one image for label {instance_label}.')
            else:
                encoder_image_name = random.choice(list(label_set - {instance_name}))
            encoder_image = load_image(self.instance_data_root / encoder_image_name)
            example["encoder_images"] = self.image_transforms(encoder_image)

            if self.prompt_in_filename:
                example["encoder_prompt"] = self._path_to_prompt(self.instance_data_root / encoder_image_name)
            else:
                example["encoder_prompt"] = self.instance_prompt

        if self.crop_head_for_encoder_image:
            example["encoder_images"] = extract_faces_and_landmarks(example["encoder_images"][None], self.size, self.mtcnn, self.resized_reference_points)[0][0]
        example["encoder_prompt"] = example["encoder_prompt"].format(placeholder="<ph>")
        example["instance_prompt"] = example["instance_prompt"].format(placeholder="<s*>")

        if random.random() < self.random_target_prob:
            random_path = random.choice(self.all_paths)

            random_image = load_image(random_path)
            example["instance_images"] = self.image_transforms(random_image)
            if self.prompt_in_filename:
                example["instance_prompt"] = self._path_to_prompt(random_path)

        if self.use_only_decoder_prompts:
            example["encoder_prompt"] = example["instance_prompt"]

        return example


def collate_fn(examples, with_prior_preservation=False):
    pixel_values = [example["instance_images"] for example in examples]
    encoder_pixel_values = [example["encoder_images"] for example in examples]
    prompts = [example["instance_prompt"] for example in examples]
    encoder_prompts = [example["encoder_prompt"] for example in examples]

    if with_prior_preservation:
        raise NotImplementedError("Prior preservation not implemented.")

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    encoder_pixel_values = torch.stack(encoder_pixel_values)
    encoder_pixel_values = encoder_pixel_values.to(memory_format=torch.contiguous_format).float()

    batch = {"pixel_values": pixel_values, "encoder_pixel_values": encoder_pixel_values,
             "prompts": prompts, "encoder_prompts": encoder_prompts}
    return batch
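
For reference, a minimal usage sketch under stated assumptions: the folder path, prompt string, and batch size below are illustrative, the file is assumed to be importable as data.dataset, and a CUDA device is assumed to be available since __init__ constructs MTCNN on cuda:0.

from torch.utils.data import DataLoader

from data.dataset import ImageDataset, collate_fn

# Illustrative values only; adjust the path and prompt to the actual training setup.
dataset = ImageDataset(
    instance_data_root="path/to/images",          # directory of .png/.jpg/.jpeg files
    instance_prompt="a photo of {placeholder}",   # "{placeholder}" is filled with "<s*>" / "<ph>"
    size=1024,
    center_crop=True,
)
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

batch = next(iter(loader))
print(batch["pixel_values"].shape)                # expected: torch.Size([2, 3, 1024, 1024])
print(batch["prompts"], batch["encoder_prompts"])

Without a metadata file, each sample's encoder image is the instance image itself (optionally re-augmented); with metadata, it is another image sharing the same label.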