import random
from typing import Optional

from PIL import Image
from torch import zeros_like
from torch.utils.data import Dataset
from torchvision import transforms

from lora_diffusion.preprocess_files import face_mask_google_mediapipe

OBJECT_TEMPLATE = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

STYLE_TEMPLATE = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]

NULL_TEMPLATE = ["{}"]

TEMPLATE_MAP = {
    "object": OBJECT_TEMPLATE,
    "style": STYLE_TEMPLATE,
    "null": NULL_TEMPLATE,
}


def _randomset(lis):
    # Return a random subset of `lis`: each element is kept with probability 0.5.
    ret = []
    for i in range(len(lis)):
        if random.random() < 0.5:
            ret.append(lis[i])
    return ret


def _shuffle(lis):
    # Return a shuffled copy of `lis` without mutating the original.
    return random.sample(lis, len(lis))


def _get_cutout_holes(
    height,
    width,
    min_holes=8,
    max_holes=32,
    min_height=16,
    max_height=128,
    min_width=16,
    max_width=128,
):
    # Sample a random number of rectangular holes, each returned as an
    # (x1, y1, x2, y2) box in pixel coordinates.
    holes = []
    for _n in range(random.randint(min_holes, max_holes)):
        hole_height = random.randint(min_height, max_height)
        hole_width = random.randint(min_width, max_width)
        y1 = random.randint(0, height - hole_height)
        x1 = random.randint(0, width - hole_width)
        y2 = y1 + hole_height
        x2 = x1 + hole_width
        holes.append((x1, y1, x2, y2))
    return holes


def _generate_random_mask(image):
    # Build a single-channel binary cutout mask for a CHW image tensor and
    # return it together with the image with the masked regions zeroed out.
    mask = zeros_like(image[:1])
    holes = _get_cutout_holes(mask.shape[1], mask.shape[2])
    for (x1, y1, x2, y2) in holes:
        mask[:, y1:y2, x1:x2] = 1.0
    # With probability 0.25, mask the entire image.
    if random.uniform(0, 1) < 0.25:
        mask.fill_(1.0)
    masked_image = image * (mask < 0.5)
    return mask, masked_image
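
# Usage sketch for `_generate_random_mask` (assumes a CHW float tensor, as
# produced by the `image_transforms` pipeline in the dataset below):
#
#     import torch
#     dummy = torch.zeros(3, 512, 512)
#     mask, masked = _generate_random_mask(dummy)
#     # mask:   (1, 512, 512), 1.0 inside the cutout holes
#     # masked: (3, 512, 512), the image with masked pixels zeroed out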


class PivotalTuningDatasetCapation(Dataset):
    """
    A dataset to prepare the instance images with the prompts for
    fine-tuning the model. It pre-processes the images and tokenizes
    the prompts.
    """

    def __init__(
        self,
        images,
        caption,
        tokenizer,
        token_map: Optional[dict] = None,
        use_template: Optional[str] = None,
        size=512,
        h_flip=True,
        color_jitter=False,
        resize=True,
        use_mask_captioned_data=False,
        use_face_segmentation_condition=False,
        train_inpainting=False,
        blur_amount: int = 70,
    ):
        self.size = size
        self.tokenizer = tokenizer
        self.resize = resize
        self.train_inpainting = train_inpainting

        assert not (
            use_mask_captioned_data and use_template
        ), "Can't use both mask caption data and template."

        # Prepare the instance images and their (shared) caption.
        self.images = images
        self.captions = [caption] * len(images)

        self.use_mask = use_face_segmentation_condition or use_mask_captioned_data
        self.use_mask_captioned_data = use_mask_captioned_data

        # Pre-compute segmentation masks once so __getitem__ can index into
        # them. This assumes face_mask_google_mediapipe returns one PIL mask
        # per input image.
        self.instance_masks = (
            face_mask_google_mediapipe(self.images) if self.use_mask else None
        )

        self.num_instance_images = len(self.images)
        self.token_map = token_map

        self.use_template = use_template
        if use_template is not None:
            self.templates = TEMPLATE_MAP[use_template]

        self._length = self.num_instance_images

        self.h_flip = h_flip
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(
                    size, interpolation=transforms.InterpolationMode.BILINEAR
                )
                if resize
                else transforms.Lambda(lambda x: x),
                transforms.ColorJitter(0.1, 0.1)
                if color_jitter
                else transforms.Lambda(lambda x: x),
                transforms.CenterCrop(size),
                transforms.ToTensor(),
                # Map [0, 1] pixel values to [-1, 1].
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        self.blur_amount = blur_amount

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = self.images[index % self.num_instance_images]
        if instance_image.mode != "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)

        if self.train_inpainting:
            (
                example["instance_masks"],
                example["instance_masked_images"],
            ) = _generate_random_mask(example["instance_images"])

        if self.use_template:
            assert self.token_map is not None
            input_tok = list(self.token_map.values())[0]
            text = random.choice(self.templates).format(input_tok)
        else:
            text = self.captions[index % self.num_instance_images].strip()
            if self.token_map is not None:
                for token, value in self.token_map.items():
                    text = text.replace(token, value)

        if self.use_mask:
            # Rescale the normalized mask from [-1, 1] to [0.5, 1.5] so it can
            # serve as a per-pixel loss weight.
            example["mask"] = (
                self.image_transforms(
                    self.instance_masks[index % self.num_instance_images]
                )
                * 0.5
                + 1.0
            )

        if self.h_flip and random.random() > 0.5:
            hflip = transforms.RandomHorizontalFlip(p=1)
            example["instance_images"] = hflip(example["instance_images"])
            if self.use_mask:
                example["mask"] = hflip(example["mask"])

        example["instance_prompt_ids"] = self.tokenizer(
            text,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        return example
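

# A minimal end-to-end smoke test (sketch, not part of the training code).
# It assumes the `transformers` package is available; the CLIP checkpoint
# below is an illustrative choice, not a requirement of this module.
if __name__ == "__main__":
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    dummy_images = [Image.new("RGB", (512, 512)) for _ in range(2)]
    dataset = PivotalTuningDatasetCapation(
        images=dummy_images,
        caption="a photo of <s1>",
        tokenizer=tokenizer,
        token_map={"<s1>": "<s1>"},
        train_inpainting=True,
    )
    sample = dataset[0]
    print(sample["instance_images"].shape)  # torch.Size([3, 512, 512])
    print(sample["instance_masks"].shape)   # torch.Size([1, 512, 512])
    print(len(sample["instance_prompt_ids"]))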