File size: 8,503 Bytes
2f74861
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
from pathlib import Path
from typing import Optional

from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms
import json
import random
from facenet_pytorch import MTCNN
import torch

from utils.utils import extract_faces_and_landmarks, REFERNCE_FACIAL_POINTS_RELATIVE

def load_image(image_path: "str | Path") -> Image.Image:
    """
    Load an image from disk as an EXIF-oriented RGB ``PIL.Image.Image``.

    Args:
        image_path: Path to the image file (``str`` or ``pathlib.Path``; the dataset
            in this module passes ``Path`` objects).

    Returns:
        The image with its EXIF orientation applied, converted to RGB if needed.
    """
    # Open inside a context manager so the file descriptor is released immediately
    # instead of whenever the lazily-loaded image happens to be garbage-collected;
    # long-running data loaders can otherwise run out of file descriptors.
    with Image.open(image_path) as raw_image:
        # exif_transpose returns a new image (the transposed one, or a copy when no
        # transposition is needed), so it remains usable after raw_image is closed.
        image = exif_transpose(raw_image)
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image


class ImageDataset(Dataset):
    """
    A dataset of instance images and prompts for fine-tuning the model.

    Each example pairs a target ("instance") image with an "encoder" image:

    * Without ``metadata_path`` the encoder image is the instance image itself,
      optionally re-augmented (random resized crop + horizontal flip) when
      ``aug_images`` is set.
    * With ``metadata_path`` (a JSON dict of ``filename relative to the root -> label``)
      the encoder image is another image carrying the same label, chosen at random.

    Images are resized, cropped to ``size`` and normalized from ``[0, 1]`` to ``[-1, 1]``.
    """

    def __init__(
            self,
            instance_data_root,
            instance_prompt,
            metadata_path: Optional[str] = None,
            prompt_in_filename=False,
            use_only_vanilla_for_encoder=False,
            concept_placeholder='a face',
            size=1024,
            center_crop=False,
            aug_images=False,
            use_only_decoder_prompts=False,
            crop_head_for_encoder_image=False,
            random_target_prob=0.0,
    ):
        """
        Args:
            instance_data_root: Directory containing the training images.
            instance_prompt: Prompt used for every image when ``prompt_in_filename`` is False.
            metadata_path: Optional JSON file mapping filenames (relative to
                ``instance_data_root``) to labels. When omitted, every ``.png``/``.jpg``/``.jpeg``
                file under the root is used, without labels.
            prompt_in_filename: Derive each prompt from the image filename (see ``_path_to_prompt``).
            use_only_vanilla_for_encoder: Only files whose name contains ``'vanilla'`` are
                eligible as encoder images (labelled mode only).
            concept_placeholder: Text substituted for ``conceptname`` in filename prompts.
            size: Output resolution of the resize + crop.
            center_crop: Center-crop instead of random-crop.
            aug_images: In unlabelled mode, augment the encoder copy of the instance image.
            use_only_decoder_prompts: Replace the encoder prompt with the (formatted) instance prompt.
            crop_head_for_encoder_image: Crop the encoder image around the face detected by MTCNN.
            random_target_prob: Probability of replacing the target image (and, with
                ``prompt_in_filename``, its prompt) with a random image from the dataset.

        Raises:
            ValueError: If ``instance_data_root`` does not exist.
        """
        # MTCNN is only needed for head cropping. Building it unconditionally on a
        # hard-coded 'cuda:0' crashed CPU-only setups even when cropping was disabled.
        self.mtcnn = None
        if crop_head_for_encoder_image:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
            self.mtcnn = MTCNN(device=device)
            # Calling the module now runs `detect` instead of the default forward;
            # presumably what extract_faces_and_landmarks expects — confirm in utils.
            self.mtcnn.forward = self.mtcnn.detect
        # Shrink the reference landmarks toward the center (0.5) by resize_factor, i.e.
        # x' = 0.5 + (x - 0.5) / resize_factor, leaving a margin around the face.
        # Assumes the reference points are relative coordinates in [0, 1] — TODO confirm.
        resize_factor = 1.3
        self.resized_reference_points = REFERNCE_FACIAL_POINTS_RELATIVE / resize_factor + (resize_factor - 1) / (2 * resize_factor)
        self.size = size
        self.center_crop = center_crop
        self.concept_placeholder = concept_placeholder
        self.prompt_in_filename = prompt_in_filename
        self.aug_images = aug_images

        self.instance_prompt = instance_prompt
        self.custom_instance_prompts = None
        self.name_to_label = None
        self.crop_head_for_encoder_image = crop_head_for_encoder_image
        self.random_target_prob = random_target_prob

        self.use_only_decoder_prompts = use_only_decoder_prompts

        self.instance_data_root = Path(instance_data_root)

        if not self.instance_data_root.exists():
            raise ValueError(f"Instance images root {self.instance_data_root} doesn't exist.")

        if metadata_path is not None:
            with open(metadata_path, 'r', encoding='utf-8') as f:
                self.name_to_label = json.load(f)  # dict of filename: label

            # Keep only files that exist on disk, both as targets and as encoder
            # candidates; otherwise a missing file could later be drawn as an encoder
            # image and fail to load in __getitem__.
            self.all_paths = []
            self.label_to_names = {}  # label -> names eligible as encoder images
            for name, label in self.name_to_label.items():
                path = self.instance_data_root / name
                if not path.exists():
                    continue
                self.all_paths.append(path)
                if use_only_vanilla_for_encoder and 'vanilla' not in name:
                    continue
                self.label_to_names.setdefault(label, []).append(name)
            print(f'Found {len(self.all_paths)} out of {len(self.name_to_label)} paths.')
        else:
            self.all_paths = [path for path in Path(instance_data_root).glob('**/*') if
                              path.suffix.lower() in (".png", ".jpg", ".jpeg")]
            # Sort by name so that order for validation remains the same across runs;
            # the full path breaks ties between equal stems (glob order is filesystem-dependent).
            self.all_paths = sorted(self.all_paths, key=lambda p: (p.stem, str(p)))

        self.custom_instance_prompts = None

        self._length = len(self.all_paths)

        self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        if self.prompt_in_filename:
            self.prompts_set = {self._path_to_prompt(path) for path in self.all_paths}
        else:
            self.prompts_set = {self.instance_prompt}

        if self.aug_images:
            self.aug_transforms = transforms.Compose(
                [
                    transforms.RandomResizedCrop(size, scale=(0.8, 1.0), ratio=(1.0, 1.0)),
                    transforms.RandomHorizontalFlip(p=0.5)
                ]
            )

    def __len__(self):
        return self._length

    def _path_to_prompt(self, path):
        """
        Build a prompt from an image filename.

        The stem is split on underscores, trailing all-numeric parts (e.g. seeds) are
        dropped, the rest is joined with spaces and ``conceptname`` is replaced by
        ``self.concept_placeholder``. E.g. ``conceptname_smiling_42.png`` becomes
        ``'a face smiling'`` with the default placeholder.
        """
        split_path = path.stem.split('_')
        # Check for emptiness first: an all-numeric stem such as "0042" used to raise IndexError.
        while split_path and split_path[-1].isnumeric():
            split_path.pop()

        prompt = ' '.join(split_path)
        # Replace placeholder in prompt with training placeholder
        return prompt.replace('conceptname', self.concept_placeholder)

    def _prompt_for(self, path):
        """Return the raw (not yet formatted) prompt for ``path``: from its filename, or the fixed instance prompt."""
        if self.prompt_in_filename:
            return self._path_to_prompt(path)
        return self.instance_prompt

    def _pick_encoder_image_name(self, instance_name, instance_label):
        """
        Pick a random image name sharing ``instance_label``, other than ``instance_name``.

        Falls back to ``instance_name`` itself (with a warning) when the label has no
        other eligible image, including when the label has no eligible image at all.
        """
        # A list (not a set) keeps the draw reproducible under a fixed random seed:
        # the iteration order of a set of strings depends on per-process hash randomization.
        candidates = [name for name in self.label_to_names.get(instance_label, []) if name != instance_name]
        if not candidates:
            # We are not supposed to have only one image per label, but just in case
            print(f'WARNING: Only one image for label {instance_label}.')
            return instance_name
        return random.choice(candidates)

    def __getitem__(self, index):
        """
        Return one example as a dict with keys ``instance_images``, ``instance_prompt``,
        ``encoder_images`` and ``encoder_prompt``.

        The ``{placeholder}`` field of the prompts is formatted to ``<s*>`` (instance)
        and ``<ph>`` (encoder).
        """
        example = {}
        instance_path = self.all_paths[index]
        example["instance_images"] = self.image_transforms(load_image(instance_path))
        example["instance_prompt"] = self._prompt_for(instance_path)

        if self.name_to_label is None:
            # If no labels, simply take the same image but with different augmentation
            example["encoder_images"] = self.aug_transforms(example["instance_images"]) if self.aug_images else example["instance_images"]
            example["encoder_prompt"] = example["instance_prompt"]
        else:
            # Randomly select another image with the same label
            instance_name = str(instance_path.relative_to(self.instance_data_root))
            encoder_image_name = self._pick_encoder_image_name(instance_name, self.name_to_label[instance_name])
            encoder_path = self.instance_data_root / encoder_image_name
            example["encoder_images"] = self.image_transforms(load_image(encoder_path))
            example["encoder_prompt"] = self._prompt_for(encoder_path)

        if self.crop_head_for_encoder_image:
            example["encoder_images"] = extract_faces_and_landmarks(example["encoder_images"][None], self.size, self.mtcnn, self.resized_reference_points)[0][0]

        if random.random() < self.random_target_prob:
            # Swap the target for a random dataset image; the encoder input is kept.
            random_path = random.choice(self.all_paths)
            example["instance_images"] = self.image_transforms(load_image(random_path))
            if self.prompt_in_filename:
                example["instance_prompt"] = self._path_to_prompt(random_path)

        # Format only after the target may have been swapped, so that a random-target
        # prompt also gets its '{placeholder}' field filled in.
        example["encoder_prompt"] = example["encoder_prompt"].format(placeholder="<ph>")
        example["instance_prompt"] = example["instance_prompt"].format(placeholder="<s*>")

        if self.use_only_decoder_prompts:
            example["encoder_prompt"] = example["instance_prompt"]

        return example


def collate_fn(examples, with_prior_preservation=False):
    """
    Merge dataset examples into a single training batch.

    Args:
        examples: Sequence of dicts as produced by ``ImageDataset.__getitem__``.
        with_prior_preservation: Unsupported; passing True raises.

    Returns:
        Dict with stacked, contiguous float tensors ``pixel_values`` and
        ``encoder_pixel_values``, plus the per-example ``prompts`` and
        ``encoder_prompts`` lists.

    Raises:
        NotImplementedError: If ``with_prior_preservation`` is True.
    """
    def gather(field):
        return [item[field] for item in examples]

    target_tensors = gather("instance_images")
    encoder_tensors = gather("encoder_images")
    target_texts = gather("instance_prompt")
    encoder_texts = gather("encoder_prompt")

    if with_prior_preservation:
        raise NotImplementedError("Prior preservation not implemented.")

    def to_batch(tensors):
        # Stack along a new leading batch dim, force a contiguous layout and float dtype.
        return torch.stack(tensors).to(memory_format=torch.contiguous_format).float()

    return {
        "pixel_values": to_batch(target_tensors),
        "encoder_pixel_values": to_batch(encoder_tensors),
        "prompts": target_texts,
        "encoder_prompts": encoder_texts,
    }