import glob
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

from PIL import Image
from torch import zeros_like
from torch.utils.data import Dataset
from torchvision import transforms

from .preprocess_files import face_mask_google_mediapipe

OBJECT_TEMPLATE = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

STYLE_TEMPLATE = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]

NULL_TEMPLATE = ["{}"]

TEMPLATE_MAP = {
    "object": OBJECT_TEMPLATE,
    "style": STYLE_TEMPLATE,
    "null": NULL_TEMPLATE,
}
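
# For reference, a training prompt is produced by formatting one of the
# templates above with the learned placeholder token. A tiny illustrative
# sketch (the token string "<s1>" is hypothetical):
#
#     random.choice(TEMPLATE_MAP["object"]).format("<s1>")
#     # -> e.g. "a photo of a <s1>"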


def _randomset(lis):
    # Keep each element independently with probability 0.5.
    return [x for x in lis if random.random() < 0.5]


def _shuffle(lis):
    # Return a shuffled copy of the list without mutating the input.
    return random.sample(lis, len(lis))


def _get_cutout_holes(
    height,
    width,
    min_holes=8,
    max_holes=32,
    min_height=16,
    max_height=128,
    min_width=16,
    max_width=128,
):
    # Sample a random number of rectangular holes as (x1, y1, x2, y2) boxes.
    # Assumes height >= max_height and width >= max_width; otherwise
    # random.randint would be called with an empty range and raise.
    holes = []
    for _n in range(random.randint(min_holes, max_holes)):
        hole_height = random.randint(min_height, max_height)
        hole_width = random.randint(min_width, max_width)
        y1 = random.randint(0, height - hole_height)
        x1 = random.randint(0, width - hole_width)
        y2 = y1 + hole_height
        x2 = x1 + hole_width
        holes.append((x1, y1, x2, y2))
    return holes


def _generate_random_mask(image):
    # Single-channel mask with the same spatial size as the (C, H, W) image.
    mask = zeros_like(image[:1])
    holes = _get_cutout_holes(mask.shape[1], mask.shape[2])
    for (x1, y1, x2, y2) in holes:
        mask[:, y1:y2, x1:x2] = 1.0
    # With probability 0.25, mask out the entire image.
    if random.uniform(0, 1) < 0.25:
        mask.fill_(1.0)
    # Zero out the masked regions to produce the inpainting input.
    masked_image = image * (mask < 0.5)
    return mask, masked_image
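

# A minimal sketch of what _generate_random_mask returns (illustrative only,
# assuming a 3x512x512 image tensor):
#
#     import torch
#     img = torch.randn(3, 512, 512)
#     mask, masked = _generate_random_mask(img)
#     mask.shape    # torch.Size([1, 512, 512]); 1.0 inside the holes
#     masked.shape  # torch.Size([3, 512, 512]); holes zeroed out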


class PivotalTuningDatasetCapation(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for
    fine-tuning the model. It pre-processes the images and tokenizes the
    prompts.
    """

    def __init__(
        self,
        instance_data_root,
        tokenizer,
        token_map: Optional[dict] = None,
        use_template: Optional[str] = None,
        size=512,
        h_flip=True,
        color_jitter=False,
        resize=True,
        use_mask_captioned_data=False,
        use_face_segmentation_condition=False,
        train_inpainting=False,
        blur_amount: int = 70,
    ):
        self.size = size
        self.tokenizer = tokenizer
        self.resize = resize
        self.train_inpainting = train_inpainting

        instance_data_root = Path(instance_data_root)
        if not instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = []
        self.mask_path = []

        assert not (
            use_mask_captioned_data and use_template
        ), "Can't use both mask-captioned data and a template."

        if use_mask_captioned_data:
            # Sort for a deterministic pairing of images with their masks and
            # caption lines.
            src_imgs = sorted(glob.glob(str(instance_data_root) + "/*src.jpg"))
            for f in src_imgs:
                idx = int(str(Path(f).stem).split(".")[0])
                mask_path = f"{instance_data_root}/{idx}.mask.png"

                if Path(mask_path).exists():
                    self.instance_images_path.append(f)
                    self.mask_path.append(mask_path)
                else:
                    print(f"Mask not found for {f}")

            with open(f"{instance_data_root}/caption.txt") as caption_file:
                self.captions = caption_file.readlines()

        else:
            possibly_src_images = (
                glob.glob(str(instance_data_root) + "/*.jpg")
                + glob.glob(str(instance_data_root) + "/*.png")
                + glob.glob(str(instance_data_root) + "/*.jpeg")
            )
            possibly_src_images = (
                set(possibly_src_images)
                - set(glob.glob(str(instance_data_root) + "/*mask.png"))
                - set([str(instance_data_root) + "/caption.txt"])
            )

            # Sort before deriving captions so the captions stay aligned with
            # the image paths after the global sort below.
            self.instance_images_path = sorted(possibly_src_images)
            # With no caption file, fall back to each file's stem as its caption.
            self.captions = [
                Path(x).name.split(".")[0] for x in self.instance_images_path
            ]

        assert (
            len(self.instance_images_path) > 0
        ), "No images found in the instance data root."

        self.instance_images_path = sorted(self.instance_images_path)

        self.use_mask = use_face_segmentation_condition or use_mask_captioned_data
        self.use_mask_captioned_data = use_mask_captioned_data

        if use_face_segmentation_condition:
            # If any expected mask is missing, regenerate the masks for every
            # image in the instance data root.
            for idx in range(len(self.instance_images_path)):
                targ = f"{instance_data_root}/{idx}.mask.png"

                if not Path(targ).exists():
                    print(f"Mask not found for {targ}")
                    print(
                        "Warning : this will pre-process all the images in the instance data root."
                    )

                    if len(self.mask_path) > 0:
                        print(
                            "Warning : masks already exist, but will be overwritten."
                        )

                    masks = face_mask_google_mediapipe(
                        [
                            Image.open(f).convert("RGB")
                            for f in self.instance_images_path
                        ]
                    )
                    for idx, mask in enumerate(masks):
                        mask.save(f"{instance_data_root}/{idx}.mask.png")

                    break

            for idx in range(len(self.instance_images_path)):
                self.mask_path.append(f"{instance_data_root}/{idx}.mask.png")

        self.num_instance_images = len(self.instance_images_path)
        self.token_map = token_map

        self.use_template = use_template
        if use_template is not None:
            self.templates = TEMPLATE_MAP[use_template]

        self._length = self.num_instance_images

        self.h_flip = h_flip
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(
                    size, interpolation=transforms.InterpolationMode.BILINEAR
                )
                if resize
                else transforms.Lambda(lambda x: x),
                transforms.ColorJitter(0.1, 0.1)
                if color_jitter
                else transforms.Lambda(lambda x: x),
                transforms.CenterCrop(size),
                transforms.ToTensor(),
                # Map pixel values from [0, 1] to [-1, 1].
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        self.blur_amount = blur_amount

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(
            self.instance_images_path[index % self.num_instance_images]
        )
        if instance_image.mode != "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)

        if self.train_inpainting:
            (
                example["instance_masks"],
                example["instance_masked_images"],
            ) = _generate_random_mask(example["instance_images"])

        if self.use_template:
            assert self.token_map is not None
            input_tok = list(self.token_map.values())[0]

            text = random.choice(self.templates).format(input_tok)
        else:
            text = self.captions[index % self.num_instance_images].strip()

            if self.token_map is not None:
                for token, value in self.token_map.items():
                    text = text.replace(token, value)

        # Log the prompt used for this example (handy for debugging captions).
        print(text)

        if self.use_mask:
            # Rescale the normalized mask from [-1, 1] to [0.5, 1.5] so it can
            # serve as a per-pixel weight (e.g. for the loss).
            example["mask"] = (
                self.image_transforms(
                    Image.open(self.mask_path[index % self.num_instance_images])
                )
                * 0.5
                + 1.0
            )

        if self.h_flip and random.random() > 0.5:
            hflip = transforms.RandomHorizontalFlip(p=1)

            example["instance_images"] = hflip(example["instance_images"])
            if self.use_mask:
                example["mask"] = hflip(example["mask"])

        example["instance_prompt_ids"] = self.tokenizer(
            text,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        return example
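

# A minimal usage sketch (illustrative only, not part of the module). It
# assumes a folder of instance images and a CLIP tokenizer from the
# `transformers` package; the token map below is hypothetical.
#
#     from transformers import CLIPTokenizer
#
#     tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
#     dataset = PivotalTuningDatasetCapation(
#         instance_data_root="./instance_images",
#         tokenizer=tokenizer,
#         token_map={"TOKEN": "<s1>"},
#         use_template="object",
#         size=512,
#     )
#     example = dataset[0]
#     example["instance_images"]      # 3 x 512 x 512 tensor in [-1, 1]
#     example["instance_prompt_ids"]  # token ids for the generated prompt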