Spaces:
Runtime error
Runtime error
import random | |
from pathlib import Path | |
from typing import Dict, List, Optional, Tuple, Union | |
from PIL import Image | |
from torch import zeros_like | |
from torch.utils.data import Dataset | |
from torchvision import transforms | |
import glob | |
from .preprocess_files import face_mask_google_mediapipe | |
# Prompt templates. Each entry has a single "{}" placeholder that the dataset
# fills via ``random.choice(templates).format(token)`` when ``use_template`` is set.
OBJECT_TEMPLATE = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

# Every style template shares the suffix " in the style of {}"; only the
# leading noun phrase varies. Note that "a cropped painting" and
# "a close-up painting" each appear twice, which doubles their chance of
# being picked by random.choice -- kept intentionally as-is.
_STYLE_PREFIXES = (
    "a painting",
    "a rendering",
    "a cropped painting",
    "the painting",
    "a clean painting",
    "a dirty painting",
    "a dark painting",
    "a picture",
    "a cool painting",
    "a close-up painting",
    "a bright painting",
    "a cropped painting",
    "a good painting",
    "a close-up painting",
    "a rendition",
    "a nice painting",
    "a small painting",
    "a weird painting",
    "a large painting",
)
STYLE_TEMPLATE = [prefix + " in the style of {}" for prefix in _STYLE_PREFIXES]

# Identity template: the prompt is just the token itself.
NULL_TEMPLATE = ["{}"]

# Lookup used by the dataset's ``use_template`` argument.
TEMPLATE_MAP = dict(
    object=OBJECT_TEMPLATE,
    style=STYLE_TEMPLATE,
    null=NULL_TEMPLATE,
)
def _randomset(lis):
    """Return a random sub-list of ``lis``.

    Each element is kept independently with probability 0.5 (one
    ``random.random()`` draw per element, in order); kept elements retain
    their original relative order.
    """
    return [item for item in lis if random.random() < 0.5]
def _shuffle(lis):
    """Return a new list holding the elements of ``lis`` in random order.

    The input is not modified; the permutation is drawn from the module-level
    ``random`` state via ``random.sample`` over the full length.
    """
    population_size = len(lis)
    return random.sample(lis, population_size)
def _get_cutout_holes(
    height,
    width,
    min_holes=8,
    max_holes=32,
    min_height=16,
    max_height=128,
    min_width=16,
    max_width=128,
):
    """Sample random rectangular holes that lie entirely inside a ``height x width`` image.

    Args:
        height: Image height in pixels.
        width: Image width in pixels.
        min_holes, max_holes: Inclusive range for the number of holes.
        min_height, max_height: Inclusive range for each hole's height.
        min_width, max_width: Inclusive range for each hole's width.

    Returns:
        A list of ``(x1, y1, x2, y2)`` tuples with ``0 <= x1 <= x2 <= width``
        and ``0 <= y1 <= y2 <= height`` (end coordinates exclusive, suitable
        for slicing).
    """
    # Clamp hole sizes to the image. Without this, an image smaller than
    # max_height/max_width makes ``height - hole_height`` negative and
    # random.randint(0, <negative>) raises ValueError. For images at least as
    # large as the max sizes the clamps are no-ops, so the random draws are
    # exactly the same as before.
    hole_h_hi = min(max_height, height)
    hole_h_lo = min(min_height, hole_h_hi)
    hole_w_hi = min(max_width, width)
    hole_w_lo = min(min_width, hole_w_hi)

    holes = []
    for _ in range(random.randint(min_holes, max_holes)):
        hole_height = random.randint(hole_h_lo, hole_h_hi)
        hole_width = random.randint(hole_w_lo, hole_w_hi)
        y1 = random.randint(0, height - hole_height)
        x1 = random.randint(0, width - hole_width)
        holes.append((x1, y1, x1 + hole_width, y1 + hole_height))
    return holes
def _generate_random_mask(image):
    """Build a random cutout mask for ``image`` plus the correspondingly masked image.

    ``image`` is sliced as ``image[:1]`` and the mask's ``shape[1]`` and
    ``shape[2]`` are used as height and width, so it presumably is a
    (C, H, W) tensor -- confirm against callers.

    Returns:
        ``(mask, masked_image)``:
        * ``mask`` has the shape of ``image[:1]``. It is 1.0 inside the
          rectangles from ``_get_cutout_holes`` and 0.0 elsewhere. With
          probability 0.25 it is instead 1.0 everywhere.
        * ``masked_image`` is ``image`` multiplied by ``mask < 0.5``, i.e.
          zero wherever the mask is set.
    """
    mask = zeros_like(image[:1])
    height, width = mask.shape[1], mask.shape[2]
    for left, top, right, bottom in _get_cutout_holes(height, width):
        mask[:, top:bottom, left:right] = 1.0

    # One time in four, mask out the whole image.
    if random.uniform(0, 1) < 0.25:
        mask.fill_(1.0)

    keep = mask < 0.5
    return mask, image * keep
class PivotalTuningDatasetCapation(Dataset):
    """
    A dataset that prepares instance images together with prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.

    Args:
        instance_data_root: Directory holding the training images. Depending on
            the mode it also holds ``caption.txt`` and ``{idx}.mask.png`` files.
        tokenizer: Callable with a ``model_max_length`` attribute whose result
            exposes ``input_ids``.
        token_map: Optional mapping used to build prompts. With
            ``use_template`` its first value is inserted into a random
            template; otherwise every key is replaced by its value in the
            caption.
        use_template: Key into ``TEMPLATE_MAP`` ("object", "style", "null").
            When given, prompts come from random templates instead of
            captions.
        size: Output resolution (optional resize, then center crop).
        h_flip: If true, each sample is flipped horizontally with probability
            0.5.
        color_jitter: Apply ``ColorJitter(0.1, 0.1)`` to the instance images.
        resize: Resize to ``size`` before center-cropping.
        use_mask_captioned_data: Load ``*src.jpg`` images, their
            ``{idx}.mask.png`` masks, and captions from ``caption.txt``
            (one per line).
        use_face_segmentation_condition: Generate missing face masks with
            ``face_mask_google_mediapipe`` and return them under ``"mask"``.
        train_inpainting: Also return random cutout masks and masked images.
        blur_amount: Stored as ``self.blur_amount``; not read elsewhere in
            this class.

    Raises:
        ValueError: if the root does not exist, if both
            ``use_mask_captioned_data`` and ``use_template`` are requested, or
            if no images are found.
    """

    def __init__(
        self,
        instance_data_root,
        tokenizer,
        token_map: Optional[dict] = None,
        use_template: Optional[str] = None,
        size=512,
        h_flip=True,
        color_jitter=False,
        resize=True,
        use_mask_captioned_data=False,
        use_face_segmentation_condition=False,
        train_inpainting=False,
        blur_amount: int = 70,
    ):
        self.size = size
        self.tokenizer = tokenizer
        self.resize = resize
        self.train_inpainting = train_inpainting

        instance_data_root = Path(instance_data_root)
        if not instance_data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        if use_mask_captioned_data and use_template:
            raise ValueError("Can't use both mask caption data and template.")

        # Prepare the instance images. Each loader returns its lists already
        # ordered, so images, masks and captions stay index-aligned.
        if use_mask_captioned_data:
            (
                self.instance_images_path,
                self.mask_path,
                self.captions,
            ) = self._load_mask_captioned_data(instance_data_root)
        else:
            self.instance_images_path, self.captions = self._load_plain_images(
                instance_data_root
            )
            self.mask_path = []

        if not self.instance_images_path:
            raise ValueError("No images found in the instance data root.")

        self.use_mask = use_face_segmentation_condition or use_mask_captioned_data
        self.use_mask_captioned_data = use_mask_captioned_data

        if use_face_segmentation_condition:
            self._prepare_face_masks(instance_data_root)
            # Face masks are named by position in instance_images_path.
            # Assign rather than append so an existing mask list is not
            # doubled.
            self.mask_path = [
                f"{instance_data_root}/{idx}.mask.png"
                for idx in range(len(self.instance_images_path))
            ]

        self.num_instance_images = len(self.instance_images_path)
        self.token_map = token_map
        self.use_template = use_template
        if use_template is not None:
            self.templates = TEMPLATE_MAP[use_template]
        self._length = self.num_instance_images
        self.h_flip = h_flip

        identity = transforms.Lambda(lambda x: x)
        resize_tf = (
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
            if resize
            else identity
        )
        self.image_transforms = transforms.Compose(
            [
                resize_tf,
                transforms.ColorJitter(0.1, 0.1) if color_jitter else identity,
                transforms.CenterCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        # Same geometry and normalization as the images, but without color
        # jitter. The mask is rescaled to x * 0.5 + 1.0 in __getitem__, and
        # random brightness/contrast would otherwise perturb those values.
        self.mask_transforms = transforms.Compose(
            [
                resize_tf,
                transforms.CenterCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.blur_amount = blur_amount

    @staticmethod
    def _load_mask_captioned_data(root):
        """Collect ``*src.jpg`` images with a matching ``{idx}.mask.png``, plus ``caption.txt`` lines.

        Returns ``(image_paths, mask_paths, captions)``. Images and masks are
        kept as pairs and ordered by the integer ``idx`` parsed from the image
        file name.
        """
        pairs = []
        for f in glob.glob(str(root) + "/*src.jpg"):
            idx = int(str(Path(f).stem).split(".")[0])
            mask_path = f"{root}/{idx}.mask.png"
            if Path(mask_path).exists():
                pairs.append((idx, f, mask_path))
            else:
                print(f"Mask not found for {f}")
        # Sort image/mask pairs together so the two lists can never drift
        # apart. (Previously only the image list was sorted later on.)
        pairs.sort()
        # NOTE(review): caption line k is used for the k-th image in this
        # order; presumably line i describes {i}.src.jpg -- verify against the
        # preprocessing step that writes caption.txt.
        with open(f"{root}/caption.txt", encoding="utf-8") as fh:
            captions = fh.readlines()
        return [p[1] for p in pairs], [p[2] for p in pairs], captions

    @staticmethod
    def _load_plain_images(root):
        """Collect ``.jpg``/``.png``/``.jpeg`` images (excluding ``*mask.png``) and derive captions.

        Returns ``(image_paths, captions)`` sorted by path. Each caption is
        the file name up to its first ``"."``.
        """
        candidates = set()
        for pattern in ("/*.jpg", "/*.png", "/*.jpeg"):
            candidates.update(glob.glob(str(root) + pattern))
        candidates -= set(glob.glob(str(root) + "/*mask.png"))
        # Sort BEFORE deriving captions so captions[i] describes
        # image_paths[i]. Deriving them from the unsorted set and sorting
        # only the paths afterwards paired images with the wrong captions.
        image_paths = sorted(candidates)
        captions = [Path(p).name.split(".")[0] for p in image_paths]
        return image_paths, captions

    def _prepare_face_masks(self, root):
        """Write ``{idx}.mask.png`` face masks for every image if any of them is missing.

        ``idx`` is the image's position in ``self.instance_images_path``.
        Nothing is regenerated when all mask files already exist.
        """
        targets = (
            f"{root}/{idx}.mask.png" for idx in range(len(self.instance_images_path))
        )
        missing = next((t for t in targets if not Path(t).exists()), None)
        if missing is None:
            return

        print(f"Mask not found for {missing}")
        print(
            "Warning : this will pre-process all the images in the instance data root."
        )
        if len(self.mask_path) > 0:
            print("Warning : masks already exists, but will be overwritten.")
        masks = face_mask_google_mediapipe(
            [Image.open(f).convert("RGB") for f in self.instance_images_path]
        )
        for idx, mask in enumerate(masks):
            mask.save(f"{root}/{idx}.mask.png")

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        """Return one training example as a dict.

        Keys:
        * ``"instance_images"``: the transformed image tensor.
        * ``"instance_prompt_ids"``: the tokenized prompt.
        * ``"mask"``: present when masks are enabled.
        * ``"instance_masks"`` and ``"instance_masked_images"``: present when
          ``train_inpainting`` is set.
        """
        example = {}
        item = index % self.num_instance_images

        with Image.open(self.instance_images_path[item]) as instance_image:
            if instance_image.mode != "RGB":
                instance_image = instance_image.convert("RGB")
            example["instance_images"] = self.image_transforms(instance_image)

        if self.train_inpainting:
            (
                example["instance_masks"],
                example["instance_masked_images"],
            ) = _generate_random_mask(example["instance_images"])

        if self.use_template:
            if self.token_map is None:
                raise ValueError("token_map is required when use_template is set.")
            input_tok = list(self.token_map.values())[0]
            text = random.choice(self.templates).format(input_tok)
        else:
            text = self.captions[item].strip()
            if self.token_map is not None:
                for token, value in self.token_map.items():
                    text = text.replace(token, value)

        if self.use_mask:
            with Image.open(self.mask_path[item]) as mask_image:
                example["mask"] = self.mask_transforms(mask_image) * 0.5 + 1.0

        if self.h_flip and random.random() > 0.5:
            hflip = transforms.RandomHorizontalFlip(p=1)
            # Flip every spatial tensor together. Previously the inpainting
            # mask and masked image were left unflipped, misaligning them
            # with the flipped instance image.
            for key in (
                "instance_images",
                "mask",
                "instance_masks",
                "instance_masked_images",
            ):
                if key in example:
                    example[key] = hflip(example[key])

        example["instance_prompt_ids"] = self.tokenizer(
            text,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        return example