|
import random |
|
from pathlib import Path |
|
from typing import Dict, List, Optional, Tuple, Union |
|
|
|
from PIL import Image |
|
from torch import zeros_like |
|
from torch.utils.data import Dataset |
|
from torchvision import transforms |
|
import glob |
|
from lora_diffusion.preprocess_files import face_mask_google_mediapipe |
|
|
|
# Prompt templates for training on a single object/subject. Each entry has one
# "{}" slot that is filled with the placeholder token at sample time.
OBJECT_TEMPLATE = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

# Prompt templates for training on an artistic style. Two prompts appear twice
# ("a cropped painting ..." and "a close-up painting ..."); they are kept as-is
# because duplicates change how often `random.choice` picks them.
STYLE_TEMPLATE = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]

# Pass-through template: the prompt is exactly the placeholder token.
NULL_TEMPLATE = ["{}"]

# Lookup from the `use_template` option value to its list of prompt templates.
TEMPLATE_MAP = {
    "object": OBJECT_TEMPLATE,
    "style": STYLE_TEMPLATE,
    "null": NULL_TEMPLATE,
}
|
|
|
|
|
def _randomset(lis):
    """Return a random subset of ``lis`` as a new list.

    Every element is kept independently when ``random.random() < 0.5``.
    Kept elements stay in their original relative order, and exactly one
    random number is drawn per element, in order.
    """
    return [element for element in lis if random.random() < 0.5]
|
|
|
|
|
def _shuffle(lis):
    """Return a new list with every element of ``lis`` in random order.

    The input is left untouched; the permutation comes from
    ``random.sample`` drawn over the full length of the input.
    """
    count = len(lis)
    permuted = random.sample(lis, count)
    return permuted
|
|
|
|
|
def _get_cutout_holes(
    height,
    width,
    min_holes=8,
    max_holes=32,
    min_height=16,
    max_height=128,
    min_width=16,
    max_width=128,
):
    """Sample random rectangular holes inside a ``height`` x ``width`` area.

    Args:
        height: Vertical extent of the area the holes must fit in.
        width: Horizontal extent of the area the holes must fit in.
        min_holes, max_holes: Inclusive bounds for the number of holes.
        min_height, max_height: Inclusive bounds for each hole's height.
        min_width, max_width: Inclusive bounds for each hole's width.

    Returns:
        A list of ``(x1, y1, x2, y2)`` tuples, where ``(x1, y1)`` is the
        top-left corner and ``(x2, y2)`` the exclusive bottom-right corner.
    """
    holes = []
    for _ in range(random.randint(min_holes, max_holes)):
        # Clamp after drawing so a hole never exceeds the area. Without this,
        # `randint(0, height - hole_height)` gets a negative upper bound and
        # raises ValueError for images smaller than the sampled hole. The
        # number and order of random draws is unchanged, so results for
        # images that already worked are identical.
        hole_height = min(random.randint(min_height, max_height), height)
        hole_width = min(random.randint(min_width, max_width), width)
        y1 = random.randint(0, height - hole_height)
        x1 = random.randint(0, width - hole_width)
        holes.append((x1, y1, x1 + hole_width, y1 + hole_height))
    return holes
|
|
|
|
|
def _generate_random_mask(image):
    """Build a random cutout mask for ``image`` and the matching masked image.

    The mask is created from ``image[:1]``, so it has a single entry along
    dimension 0; dimension 1 is treated as height and dimension 2 as width.
    Hole pixels are set to 1.0. In a quarter of the calls the whole mask is
    set to 1.0 instead of only the sampled holes.

    Returns:
        ``(mask, masked_image)`` where ``masked_image`` is ``image`` with
        every masked position zeroed out.
    """
    mask = zeros_like(image[:1])

    # Holes are sampled first so the order of random draws stays fixed.
    for left, top, right, bottom in _get_cutout_holes(mask.shape[1], mask.shape[2]):
        mask[:, top:bottom, left:right] = 1.0

    mask_everything = random.uniform(0, 1) < 0.25
    if mask_everything:
        mask.fill_(1.0)

    # Keep only the pixels that are outside the mask.
    visible = mask < 0.5
    return mask, image * visible
|
|
|
|
|
class PivotalTuningDatasetCapation(Dataset):
    """Dataset of in-memory instance images paired with a prompt for fine-tuning.

    Each item is a dict holding the transformed image tensor
    (``instance_images``) and the tokenized prompt (``instance_prompt_ids``).
    Depending on the options it also holds a ``mask`` entry and/or the
    inpainting entries ``instance_masks`` and ``instance_masked_images``.

    Args:
        images: Sequence of images. Each must expose ``.mode`` and
            ``.convert`` (presumably ``PIL.Image`` objects - confirm with caller).
        caption: Caption shared by every image; used when no template is set.
        tokenizer: Callable invoked as ``tokenizer(text, padding=...,
            truncation=..., max_length=...)``. It must expose
            ``model_max_length`` and return an object with ``input_ids``.
        token_map: Optional mapping of placeholder text to replacement text.
            In caption mode every key is replaced by its value; in template
            mode the first value fills the template slot.
        use_template: Optional key of ``TEMPLATE_MAP`` ("object", "style" or
            "null"). When set, prompts come from a random template instead
            of ``caption``.
        size: Target size for the resize and centre-crop steps.
        h_flip: Flip a sample horizontally when ``random.random() > 0.5``.
        color_jitter: Add ``ColorJitter(0.1, 0.1)`` to the pipeline.
        resize: Add the bilinear ``Resize(size)`` step to the pipeline.
        use_mask_captioned_data: Enables the ``mask`` entry; cannot be
            combined with ``use_template``.
        use_face_segmentation_condition: Also enables the ``mask`` entry.
        train_inpainting: Add a random cutout mask and masked image per item.
        blur_amount: Stored on the instance; not used inside this class.

    Note:
        This class never assigns ``mask_path``. When a mask option is
        enabled, a list of mask image paths (one per image) must be assigned
        to ``dataset.mask_path`` before items are fetched.
    """

    def __init__(
        self,
        images,
        caption,
        tokenizer,
        token_map: Optional[dict] = None,
        use_template: Optional[str] = None,
        size=512,
        h_flip=True,
        color_jitter=False,
        resize=True,
        use_mask_captioned_data=False,
        use_face_segmentation_condition=False,
        train_inpainting=False,
        blur_amount: int = 70,
    ):
        self.size = size
        self.tokenizer = tokenizer
        self.resize = resize
        self.train_inpainting = train_inpainting

        assert not (
            use_mask_captioned_data and use_template
        ), "Can't use both mask caption data and template."

        self.images = images
        # Every image shares the same caption.
        self.captions = [caption] * len(images)

        self.use_mask = use_face_segmentation_condition or use_mask_captioned_data
        self.use_mask_captioned_data = use_mask_captioned_data

        self.num_instance_images = len(self.images)
        self.token_map = token_map

        self.use_template = use_template
        if use_template is not None:
            self.templates = TEMPLATE_MAP[use_template]

        self._length = self.num_instance_images

        self.h_flip = h_flip

        # Optional steps are simply left out instead of being replaced by an
        # identity `transforms.Lambda(lambda x: x)`: lambdas cannot be
        # pickled, which breaks DataLoader workers started with "spawn".
        pipeline = []
        if resize:
            pipeline.append(
                transforms.Resize(
                    size, interpolation=transforms.InterpolationMode.BILINEAR
                )
            )
        if color_jitter:
            pipeline.append(transforms.ColorJitter(0.1, 0.1))
        pipeline.extend(
            [
                transforms.CenterCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.image_transforms = transforms.Compose(pipeline)

        self.blur_amount = blur_amount

    def __len__(self):
        """Return the number of instance images."""
        return self._length

    def __getitem__(self, index):
        """Build the training example for ``index`` (wrapped modulo the image count).

        Returns:
            A dict with ``instance_images`` and ``instance_prompt_ids``, plus
            ``mask`` when a mask option is enabled and ``instance_masks`` /
            ``instance_masked_images`` when ``train_inpainting`` is set.

        Raises:
            AttributeError: A mask option is enabled but ``mask_path`` was
                never assigned on this dataset.
        """
        example = {}
        slot = index % self.num_instance_images

        instance_image = self.images[slot]
        if instance_image.mode != "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)

        if self.train_inpainting:
            (
                example["instance_masks"],
                example["instance_masked_images"],
            ) = _generate_random_mask(example["instance_images"])

        if self.use_template:
            assert self.token_map is not None
            # Template mode: the first mapped token fills the template slot.
            input_tok = list(self.token_map.values())[0]
            text = random.choice(self.templates).format(input_tok)
        else:
            text = self.captions[slot].strip()
            if self.token_map is not None:
                for token, value in self.token_map.items():
                    text = text.replace(token, value)

        if self.use_mask:
            # `mask_path` is not set by __init__; fail with an actionable
            # message (same exception type as the bare attribute lookup).
            mask_paths = getattr(self, "mask_path", None)
            if mask_paths is None:
                raise AttributeError(
                    "mask_path is not set on this dataset; assign a list of "
                    "mask image paths (one per image) before indexing when "
                    "use_mask_captioned_data or "
                    "use_face_segmentation_condition is enabled."
                )
            with Image.open(mask_paths[slot]) as mask_image:
                mask_tensor = self.image_transforms(mask_image)
            # The pipeline's Normalize step is followed by this affine
            # rescale of the mask values.
            example["mask"] = mask_tensor * 0.5 + 1.0

        if self.h_flip and random.random() > 0.5:
            hflip = transforms.RandomHorizontalFlip(p=1)
            # Flip every spatial tensor together. The inpainting mask and
            # masked image were derived from the unflipped image, so leaving
            # them unflipped would misalign them with `instance_images`.
            for key in (
                "instance_images",
                "mask",
                "instance_masks",
                "instance_masked_images",
            ):
                if key in example:
                    example[key] = hflip(example[key])

        example["instance_prompt_ids"] = self.tokenizer(
            text,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        return example
|
|